Example #1
def projectParam_vec(param, N, D, G, M, K, lb=1e-6):
    # unpack the input parameter vector
    tp_1 = [0, M, 2*M, 3*M, 4*M, 4*M+M*G, 4*M+M*G+G, 4*M+M*G+2*G, 
            4*M+M*G+2*G+G*N*K, 4*M+M*G+2*G+G*(N+D)*K, 4*M+M*G+2*G+G*(N+2*D)*K,
            4*M+M*G+2*G+G*(N+2*D+1)*K, 4*M+M*G+2*G+G*(N+2*D+2)*K]
    tp_2 = []
    for i in np.arange(len(tp_1)-1):
        tp_2.append(param[tp_1[i] : tp_1[i+1]])
    [tau_a1, tau_a2, tau_b1, tau_b2, phi, tau_v1, tau_v2, eta, mu_w, sigma_w,\
            mu_b, sigma_b] = tp_2
    phi = np.reshape(phi, (M,G))
    eta = np.reshape(eta, (G,N,K))
    
    # apply projections
    w_tau_ab = projectLB(np.concatenate((tau_a1,tau_a2,tau_b1,tau_b2)), lb)
    w_phi = np.zeros((M,G))
    for m in np.arange(M):
        w_phi[m] = projectSimplex_vec(phi[m])
    w_tau_v = projectLB(np.concatenate((tau_v1,tau_v2)), lb)

    w_eta = np.zeros((G,N,K))
    for g in np.arange(G):
        for n in np.arange(N):
            w_eta[g,n] = projectSimplex_vec(eta[g,n])

    w = np.concatenate((w_tau_ab, w_phi.reshape(M*G), w_tau_v, \
            w_eta.reshape(G*N*K), mu_w, projectLB(sigma_w,lb), mu_b, \
            projectLB(sigma_b,lb)))
    return w
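projectLB and projectSimplex_vec are helpers defined elsewhere in the source module; a minimal sketch consistent with how they are used here (elementwise lower bound, and Euclidean projection onto the probability simplex) might look like:

import numpy as np

def projectLB(x, lb):
    # clamp every entry of x to be at least lb
    return np.maximum(x, lb)

def projectSimplex_vec(v):
    # Euclidean projection of v onto {w : w >= 0, sum(w) = 1}
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u + (1.0 - css) / (np.arange(len(v)) + 1) > 0)[0][-1]
    theta = (1.0 - css[rho]) / (rho + 1.0)
    return np.maximum(v + theta, 0.0)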
Example #2
    def callback(params):
        print("Log likelihood {}".format(-objective(params)))
        plt.cla()
        print(params)
        # Show posterior marginals.
        plot_xs = np.reshape(np.linspace(-7, 7, 300), (300,1))
        pred_mean, pred_cov = predict(params, X, y, plot_xs)
        marg_std = np.sqrt(np.diag(pred_cov))
        ax.plot(plot_xs, pred_mean, 'b')
        ax.fill(np.concatenate([plot_xs, plot_xs[::-1]]),
                np.concatenate([pred_mean - 1.96 * marg_std,
                               (pred_mean + 1.96 * marg_std)[::-1]]),
                alpha=.15, fc='Blue', ec='None')

        # Show samples from posterior.
        rs = npr.RandomState(0)
        sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10)
        ax.plot(plot_xs, sampled_funcs.T)

        ax.plot(X, y, 'kx')
        ax.set_ylim([-1.5, 1.5])
        ax.set_xticks([])
        ax.set_yticks([])
        plt.draw()
        plt.pause(1.0/60.0)
Example #3
    def plot_single_gp(ax, params, layer, unit, plot_xs):
        ax.cla()
        rs = npr.RandomState(0)

        deep_map = create_deep_map(params)
        gp_details = deep_map[layer][unit]
        gp_params = pack_gp_params(gp_details)

        pred_mean, pred_cov = predict_layer_funcs[layer][unit](gp_params, plot_xs, with_noise = False, FITC = False)
        x0 = deep_map[layer][unit]['x0']
        y0 = deep_map[layer][unit]['y0']
        noise_scale = deep_map[layer][unit]['noise_scale']

        marg_std = np.sqrt(np.diag(pred_cov))
        if n_samples_to_plot > 19:
            ax.plot(plot_xs, pred_mean, 'b')
            ax.fill(np.concatenate([plot_xs, plot_xs[::-1]]),
                    np.concatenate([pred_mean - 1.96 * marg_std,
                                    (pred_mean + 1.96 * marg_std)[::-1]]),
                    alpha=.15, fc='Blue', ec='None')

        # Show samples from posterior.
        sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov*(random), size=n_samples_to_plot)
        ax.plot(plot_xs, sampled_funcs.T)
        ax.plot(x0, y0, 'ro')
        #ax.errorbar(x0, y0, yerr = noise_scale, fmt='o')
        ax.set_xticks([])
        ax.set_yticks([])
Example #4
def make_pinwheel_data(num_spokes=5, points_per_spoke=40, rate=1.0, noise_std=0.005):
    """Make synthetic data in the shape of a pinwheel."""
    spoke_angles = np.linspace(0, 2 * np.pi, num_spokes + 1)[:-1]
    rs = npr.RandomState(0)
    x = np.linspace(0.1, 1, points_per_spoke)
    xs = np.concatenate([x * np.cos(angle + x * rate) + noise_std * rs.randn(len(x)) for angle in spoke_angles])
    ys = np.concatenate([x * np.sin(angle + x * rate) + noise_std * rs.randn(len(x)) for angle in spoke_angles])
    return np.concatenate([np.expand_dims(xs, 1), np.expand_dims(ys, 1)], axis=1)
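A quick way to eyeball the generator's output (a usage sketch; assumes the example's numpy/numpy.random imports plus matplotlib):

import matplotlib.pyplot as plt

data = make_pinwheel_data(num_spokes=5, points_per_spoke=40)  # shape (200, 2)
plt.scatter(data[:, 0], data[:, 1], s=5)
plt.axis('equal')
plt.show()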
Example #5
 def Lambda(self):
     lam = super(LinearDiscreteHawkes, self).Lambda
     w = self.W[self.K:, :]
     assert w.shape == (self.K, self.K), "unmatched w shape: should be KxK"
     Rs = []
     for t in range(self.conv_dyad_data.shape[2]):
         Rs.append(self.conv_dyad_data[:, :, t] * w)
     Rs = np.concatenate(Rs, axis=1).reshape(self.conv_dyad_data.shape, order='F')
     return np.concatenate([lam[i, :, :] + Rs[:, :, i] for i in range(self.data.shape[0])]).reshape(lam.shape)
Example #6
def make_pinwheel_data(num_classes, num_per_class, rate=2.0, noise_std=0.001):
    spoke_angles = np.linspace(0, 2*np.pi, num_classes+1)[:-1]

    rs = npr.RandomState(0)
    x = np.linspace(0.1, 1, num_per_class)
    xs = np.concatenate([rate * x * np.cos(angle + x * rate) + noise_std * rs.randn(num_per_class)
                         for angle in spoke_angles])
    ys = np.concatenate([rate * x * np.sin(angle + x * rate) + noise_std * rs.randn(num_per_class)
                         for angle in spoke_angles])
    return np.concatenate([np.expand_dims(xs, 1), np.expand_dims(ys, 1)], axis=1)
Example #7
def flatten(value):
    """value can be any nesting of tuples, arrays, dicts.
       returns 1D numpy array and an unflatten function."""
    if isinstance(getval(value), np.ndarray):
        def unflatten(vector):
            return np.reshape(vector, value.shape)
        return np.ravel(value), unflatten

    elif isinstance(getval(value), float):
        return np.array([value]), lambda x : x[0]

    elif isinstance(getval(value), tuple):
        if not value:
            return np.array([]), lambda x : ()
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return (unflatten_first(vector[:N]),) + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    elif isinstance(getval(value), list):
        if not value:
            return np.array([]), lambda x : []
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return [unflatten_first(vector[:N])] + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    elif isinstance(getval(value), dict):
        flattened = []
        unflatteners = []
        lengths = []
        keys = []
        for k, v in sorted(iteritems(value), key=itemgetter(0)):
            cur_flattened, cur_unflatten = flatten(v)
            flattened.append(cur_flattened)
            unflatteners.append(cur_unflatten)
            lengths.append(len(cur_flattened))
            keys.append(k)

        def unflatten(vector):
            split_ixs = np.cumsum(lengths)
            pieces = np.split(vector, split_ixs)
            return {key: unflattener(piece)
                    for piece, unflattener, key in zip(pieces, unflatteners, keys)}

        return np.concatenate(flattened), unflatten

    else:
        raise Exception("Don't know how to flatten type {}".format(type(value)))
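A round-trip usage sketch (getval, iteritems, and itemgetter come from autograd, six, and operator in the original module; for plain values getval is the identity):

import numpy as np

params = {'w': np.ones((2, 3)), 'b': (0.5, np.zeros(4))}
vec, unflatten = flatten(params)
assert vec.shape == (11,)                 # 6 + 1 + 4 flattened values
restored = unflatten(vec)
assert np.allclose(restored['w'], params['w'])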
Example #8
File: ml.py  Project: gablg1/ml-util
def ridgeData(X_train, Y_train, regularization_factor):
    # This wasn't tested for dim(X_train) != 2 or dim(Y_train) != 1
    N, D = X_train.shape
    assert(Y_train.shape == (N,))
    ridge_matrix = np.sqrt(regularization_factor) * np.identity(D)
    X_trainp = np.concatenate((X_train, ridge_matrix), 0)

    zeros = np.zeros([D for i in range(np.ndim(Y_train))])
    assert(np.ndim(zeros) == np.ndim(Y_train))
    Y_trainp = np.concatenate((Y_train, zeros), 0)
    return X_trainp, Y_trainp
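Stacking sqrt(lambda) * I beneath X turns ridge regression into an ordinary least-squares problem on the augmented data; a minimal sketch of that equivalence:

import numpy as np

rs = np.random.RandomState(0)
X, Y = rs.randn(50, 3), rs.randn(50)
Xp, Yp = ridgeData(X, Y, regularization_factor=0.1)
w_ridge = np.linalg.lstsq(Xp, Yp, rcond=None)[0]
# w_ridge minimizes ||X w - Y||^2 + 0.1 * ||w||^2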
Example #9
def calc_side_matrices(operators, operators_bar, obs, test_points, op_cache, fun_args=None):
    obs_points = np.r_[[p for p, _ in obs]]
    L = []
    Lbar = []
    for op, op_bar, point in zip(operators, operators_bar, obs_points):
        f = op_cache[(op,)]
        fbar = op_cache[(op_bar,)]
        L.append(f(point, test_points, fun_args))
        Lbar.append(fbar(test_points, point, fun_args))
    L = np.concatenate(L)
    Lbar = np.concatenate(Lbar, axis=1)
    return L, Lbar
def get_Aopt(inX, iny):
    X_train, y_train, X_test, y_test = ascdata.split_train_test(inX, iny)
    X_train = np.concatenate((X_train, np.ones((X_train.shape[0], 1))), 1)
    X_test = np.concatenate((X_test, np.ones((X_test.shape[0], 1))), 1)
    X_train_less, s_train = ascdata.split_X_s(X_train)
    X_test_less, s_test = ascdata.split_X_s(X_test)

    s_train_phi = ascdata.generate_phi(s_train, d, A_phi, b_phi)
    s_test_phi = ascdata.generate_phi(s_test, d, A_phi, b_phi)

    nfeatures = X_train.shape[1] - 1
    # Dimensions of phi(s)
    nfeatures_phi = d
    invT2 = 10

    def logprob(inA, inX, iny, ins_phi):
        RMS = 0
        for i in range(len(iny)):
            wi = np.dot(inA, inX[i])
            RMS_current = (iny[i] - np.dot(wi, ins_phi[i]))**2
            RMS += RMS_current
        return -RMS

    objective = lambda inA, t: -logprob(inA, X_train_less, y_train, s_train_phi)

    LLHs = []
    LLH_xs = []

    def callback(params, t, g):
        LLH = -objective(params, t)
        LLHs.append(LLH)
        LLH_xs.append(t)
        print("Iteration {} log likelihood {}".format(t, LLH))

    init_A = 1e-11 * np.ones((nfeatures_phi, nfeatures))
    # init_A =  [[ -3.05236728e-04,  -9.50015728e-04,  -3.80139503e-04,   1.44010470e-04, -3.05236728e-04,
    #              -4.96117987e-04,  -1.02736409e-04,  -1.86416292e-04, -9.52628589e-04,  -1.55023279e-03,
    #              1.44717581e-04,   1.00000000e-11, -9.50028200e-04,  -4.96117987e-04,   1.00000000e-11,
    #              -3.05236728e-04, 1.77416412e-06,  -8.16665436e-06,   3.12622951e-05,  -8.25700143e-04,
    #              1.44627987e-04,   1.90211243e-05,  -8.28273186e-04,  -9.41349990e-04, -4.56671031e-04,
    #              9.79097070e-03,  -6.41866046e-04,  -7.79274856e-05, 1.44539330e-04,  -3.05236728e-04,
    #              -5.99188450e-04,  -7.29470175e-04, -6.69558174e-04,  -9.50028200e-04]]
    init_A = np.array(init_A)

    print("Optimizing network parameters...")
    optimized_params = adam(grad(objective), init_A,
                            step_size=0.01, num_iters=1000, callback=callback)

    Aopt = optimized_params
    print "Aopt = ", Aopt

    return Aopt, X_train_less, y_train, s_train, X_test_less, y_test, s_test, LLHs, LLH_xs
Example #11
def projectParam(param, N, D, G, M, K, lb=1e-6):
    """ project variational parameter vector onto the constraint set, including
    positive constraints for parameters of Beta distributions, simplex
    constraints for parameters of Categorical distributions
    
    Parameters
    ----------
    param: length (2M + 2M + MG + 2G + GDK + GDK + GK + GK) 
        variational parameters, including:
        1) tau_a1: len(M), first parameter of q(alpha_m)
        2) tau_a2: len(M), second parameter of q(alpha_m)
        3) tau_b1: len(M), first parameter of q(beta_m)
        4) tau_b2: len(M), second parameter of q(beta_m)
        5) phi: shape(M, G), phi[m,:] is the parameter vector of q(c_m)
        6) tau_v1: len(G), first parameter of q(nu_g)
        7) tau_v2: len(G), second parameter of q(nu_g)
        8) mu_w: shape(G, D, K), mu_w[g,d,k] is the mean parameter of 
            q(W^g_{dk})
        9) sigma_w: shape(G, D, K), sigma_w[g,d,k] is the std parameter of 
            q(W^g_{dk})
        10) mu_b: shape(G, K), mu_b[g,k] is the mean parameter of q(b^g_k)
        11) sigma_b: shape(G, K), sigma_b[g,k] is the std parameter of q(b^g_k)
    N,D,G,M,K: number of samples (N), features(D), groups(G), experts(M),
        clusters(K)
    lb: float, lower bound of positive constraints
     
    Returns
    -------
    w: length (2M + 2M + MG + 2G + GDK + GDK + GK + GK)
    """
    # unpack the input parameter vector
    tp_1 = [0, M, 2*M, 3*M, 4*M, 4*M+M*G, 4*M+M*G+G, 
            4*M+M*G+2*G, 4*M+M*G+2*G+G*D*K, 4*M+M*G+2*G+G*(2*D)*K,
            4*M+M*G+2*G+G*(2*D+1)*K, 4*M+M*G+2*G+G*(2*D+2)*K]
    tp_2 = []
    for i in np.arange(len(tp_1)-1):
        tp_2.append(param[tp_1[i] : tp_1[i+1]])
    [tau_a1, tau_a2, tau_b1, tau_b2, phi, tau_v1, tau_v2, mu_w, sigma_w,\
            mu_b, sigma_b] = tp_2
    phi = np.reshape(phi, (M,G))
     
    # apply projections
    w_tau_ab = projectLB(np.concatenate((tau_a1,tau_a2,tau_b1,tau_b2)), lb)
    
    w_phi_vec = np.reshape(projectSimplex(phi), M*G)

    w_tau_v = projectLB(np.concatenate((tau_v1,tau_v2)), lb)
    
    w = np.concatenate((w_tau_ab, w_phi_vec, w_tau_v, \
            mu_w, projectLB(sigma_w,lb), mu_b, projectLB(sigma_b,lb)))
    return w
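projectSimplex is evidently the row-wise counterpart of projectSimplex_vec from Example #1; a minimal sketch under that assumption:

def projectSimplex(phi):
    # project each row of phi onto the probability simplex
    return np.array([projectSimplex_vec(row) for row in phi])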
Example #12
    def mean(self, test_points, g=None):
        if g is None:
            g = np.concatenate([val for _, val in self.__obs])
        L, Lbar = calc_side_matrices(self.__operators, self.__operators_bar, self.__obs, test_points, self.__op_cache, self.__fun_args)
        mu_multiplier = np.dot(Lbar, self.__LLbar_inv)

        return np.dot(mu_multiplier, g)
def preprocess_data(df, nodes_nbrs, graph_idxs, graph_nodes, graph_array):
    intersect = set(df['seqid'].values).intersection(graph_idxs.keys())
    # Get a reduced list of graph_idxs.
    graph_idxs_reduced = dict()
    graph_nodes_reduced = dict()
    for g in intersect:
        graph_idxs_reduced[g] = graph_idxs[g]
        graph_nodes_reduced[g] = graph_nodes[g]
    # return intersect, graph_idxs_reduced, graph_nodes_reduced

    # Initialize a zero-matrix.
    idxs = np.concatenate([i for i in graph_idxs_reduced.values()])
    graph_arr_fin = np.zeros(shape=graph_array[idxs].shape)

    # Initialize empty maps of graph indices from the old to the new.
    nodes_oldnew = dict()  # {old_idx: new_idx}.
    nodes_newold = dict()  # {new_idx: old_idx}

    # Re-assign reduced graphs to the zero-matrix.
    curr_idx = 0
    for seqid, idxs in sorted(graph_idxs_reduced.items()):
        for idx in idxs:
            nodes_oldnew[idx] = curr_idx
            nodes_newold[curr_idx] = idx
            graph_arr_fin[curr_idx] = graph_array[idx]
            curr_idx += 1

    nodes_nbrs_fin = filter_and_reindex_nodes_nbrs(nodes_nbrs, nodes_oldnew)
    graph_idxs_fin = filter_and_reindex_graph_idxs(graph_idxs, nodes_oldnew)
    graph_nodes_fin = filter_and_reindex_graph_nodes(graph_nodes, nodes_oldnew)

    return graph_arr_fin, nodes_nbrs_fin, graph_idxs_fin, graph_nodes_fin
Example #14
 def plot_gmm(params, ax, num_points=100):
     angles = np.expand_dims(np.linspace(0, 2*np.pi, num_points), 1)
     xs, ys = np.cos(angles), np.sin(angles)
     circle_pts = np.concatenate([xs, ys], axis=1) * 2.0
     for log_proportion, mean, chol in zip(*unpack_params(params)):
         cur_pts = mean + np.dot(circle_pts, chol)
         ax.plot(cur_pts[:, 0], cur_pts[:, 1], '-')
Example #15
def initParam(prior, X, N, D, G, M, K, dir_param, prng):
    """ initialize variational parameters with prior parameters
    """
    
    [tpM, tpG, lb, ub] = [np.ones(M), np.ones(G), 10., 10.]
    tpR = prng.rand(2*M)
    [tau_a1, tau_a2, tau_b1, tau_b2, tau_v1, tau_v2] = \
            [lb+(ub-lb)*tpR[0 : M], tpM,\
             lb+(ub-lb)*tpR[M : 2*M], tpM, \
             tpG, tpG]

    mu_w = prng.randn(G,D,K)/np.sqrt(D)
    sigma_w = np.ones(G*D*K) * 1e-3
    mu_b = prng.randn(G, K)/np.sqrt(D)
    sigma_b = np.ones(G*K) * 1e-3

    phi = np.reshape(prng.dirichlet(np.ones(G)*dir_param, M), M*G)
    
    mu_w = np.reshape(mu_w, G*D*K)
    mu_b = np.reshape(mu_b, G*K)

    param_init = np.concatenate((tau_a1, tau_a2, tau_b1, tau_b2, phi, tau_v1,\
        tau_v2, mu_w, sigma_w, mu_b, sigma_b))
    
    return param_init
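The prior and X arguments are never read in this body, so a quick sanity check can pass placeholders; the resulting length matches the layout unpacked by projectParam:

import numpy as np

prng = np.random.RandomState(0)
N, D, G, M, K = 100, 5, 3, 4, 2
param = initParam(None, None, N, D, G, M, K, dir_param=1.0, prng=prng)
assert param.size == 4*M + M*G + 2*G + 2*G*D*K + 2*G*K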
Example #16
def handle_time_inds(times, h=None):
    """
    Takes a list of time vectors and returns the unique, potentially augmented,
    combined time vector, together with indices locating each original vector in it.
    """
    # get size of each vector
    t_sizes = [len(t) for t in times]

    # concatenate to single time vector
    tt = np.concatenate(times)

    # get the distinct times, and the indices
    tt_uni, inv_ind = np.unique(tt, return_inverse=True)

    # split inv_ind up
    ind_ti = util._unpack_vector(inv_ind, t_sizes)

    if h is None:
        return tt_uni, ind_ti

    elif isinstance(h, float) and h > 0:
        # augment the time vector so that diff is at most h
        ttc, inds_c = augment_times(tt_uni, h)
        data_inds = [inds_c[ind_i] for ind_i in ind_ti]
        return ttc, data_inds

    else:
        raise ValueError("h should be a float > 0")
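The heavy lifting is done by np.unique with return_inverse=True, which deduplicates the concatenated times and records where each original entry landed; util._unpack_vector then just splits those indices back by the original vector lengths. A bare-numpy illustration of the mechanism:

import numpy as np

times = [np.array([0.0, 0.5, 1.0]), np.array([0.5, 1.5])]
tt_uni, inv_ind = np.unique(np.concatenate(times), return_inverse=True)
# tt_uni  -> array([0. , 0.5, 1. , 1.5])
# inv_ind -> array([0, 1, 2, 1, 3]), split back as [0, 1, 2] and [1, 3]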
Example #17
def jacobian(fun, argnum=0):
    """
    Returns a function which computes the Jacobian of `fun` with respect to
    positional argument number `argnum`, which must be a scalar or array. Unlike
    `grad` it is not restricted to scalar-output functions, but also it cannot
    take derivatives with respect to some argument types (like lists or dicts).
    If the input to `fun` has shape (in1, in2, ...) and the output has shape
    (out1, out2, ...) then the Jacobian has shape (out1, out2, ..., in1, in2, ...).
    """
    def getshape(val):
        val = getval(val)
        assert np.isscalar(val) or isinstance(val, np.ndarray), \
            'Jacobian requires input and output to be scalar- or array-valued'
        return np.shape(val)

    def unit_vectors(shape):
        for idxs in it.product(*map(range, shape)):
            vect = np.zeros(shape)
            vect[idxs] = 1
            yield vect

    concatenate = lambda lst: np.concatenate(list(map(np.atleast_1d, lst)))

    @attach_name_and_doc(fun, argnum, 'Jacobian')
    def jacfun(*args, **kwargs):
        vjp, ans = make_vjp(fun, argnum)(*args, **kwargs)
        outshape = getshape(ans)
        grads = map(vjp, unit_vectors(outshape))
        jacobian_shape = outshape + getshape(args[argnum])
        return np.reshape(concatenate(grads), jacobian_shape)

    return jacfun
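This appears to be autograd's jacobian; a usage sketch checking the documented output shape (assuming autograd is installed):

import autograd.numpy as anp
from autograd import jacobian

f = lambda x: x ** 2                          # R^3 -> R^3, elementwise
J = jacobian(f)(anp.array([1.0, 2.0, 3.0]))
assert J.shape == (3, 3)                      # equals diag(2 * x)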
Example #18
 def Lambda(self):
     w = self.W
     Rs = []
     for k in range(self.K):
         wk = np.tile(w[:self.K, k], (self.conv_data.shape[0], 1))
         Rs.append((self.conv_data * wk)[:, None])
     return np.concatenate(Rs, axis=2).reshape((self.conv_data.shape[0], self.K, self.K), order='F')
def GenerateDataset(Size, Nfeat, Npoints, Nneurons):
    
    X, Y = [], []
    
    for i in range(Size):
        # generate random ffnn
        #P = FFNN_Parameters(Nfeat, Nneurons, 1, 20)
        x = rnd(Npoints, Nfeat)*20  # rnd: presumably numpy.random.rand
        #P[2] = np.abs(P[2])
        #y = np.dot( np.tanh( np.dot(x, P[0]) + P[1]),  P[2] )
        pr = np.random.randn(x.shape[1], 1)
        pr = pr > 0
        y = np.dot(x, pr)*0.5
        xy = np.concatenate((x,y), axis=1)
        """
        W = np.vstack((P[0], [P[1]] , np.transpose( P[2] ))) #
        W = np.transpose(W)
        W = W[np.lexsort(np.fliplr(W).T)]
        W = np.transpose(W)"""
        
        xval = xy.flatten()
        yval = pr.flatten()
        
        if i == 0:  # first iteration: allocate the output arrays
            X = np.zeros((Size, xval.shape[0]))
            Y = np.zeros((Size, yval.shape[0]))
        
        X[i,:] = xval
        Y[i,:] = yval
    
    return X,Y
Example #20
def TrackNormal(x):
    # x is a complex array of 2-D track points (x + 1j*y); wrap around the ends
    xx = np.concatenate([x[-1:], x, x[:1]])
    p0 = xx[:-2]
    p2 = xx[2:]
    T = p2 - p0  # central-difference track derivative
    uT = np.abs(T)
    return T / uT  # unit tangent at each point
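The wrap-around concatenation and the use of np.abs as a norm imply that x is a complex array encoding 2-D points; a usage sketch on a circular track:

import numpy as np

t = np.linspace(0, 2 * np.pi, 100, endpoint=False)
x = np.cos(t) + 1j * np.sin(t)        # closed circular track, points as x + iy
uT = TrackNormal(x)
assert np.allclose(np.abs(uT), 1.0)   # unit tangents everywhere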
Example #21
def jacobian(fun, argnum=0):
    """
    Returns a function which computes the Jacobian of `fun` with respect to
    positional argument number `argnum`, which must be a scalar or array. Unlike
    `grad` it is not restricted to scalar-output functions, but also it cannot
    take derivatives with respect to some argument types (like lists or dicts).
    If the input to `fun` has shape (in1, in2, ...) and the output has shape
    (out1, out2, ...) then the Jacobian has shape (out1, out2, ..., in1, in2, ...).
    """
    dummy = lambda: None

    def getshape(val):
        val = getval(val)
        assert np.isscalar(val) or isinstance(val, np.ndarray), \
            'Jacobian requires input and output to be scalar- or array-valued'
        return np.shape(val)

    def list_fun(*args, **kwargs):
        val = fun(*args, **kwargs)
        dummy.outshape = getshape(val)
        return list(np.ravel(val))

    concatenate = lambda lst: np.concatenate(list(map(np.atleast_1d, lst)))

    @attach_name_and_doc(fun, argnum, 'Jacobian')
    def jacfun(*args, **kwargs):
        start_node, end_nodes, tape = forward_pass(list_fun, args, kwargs, argnum)
        grads = list(map(partial(backward_pass, start_node, tape=tape), end_nodes))
        shape = dummy.outshape + getshape(args[argnum])
        return np.reshape(concatenate(grads), shape) if shape else grads[0]
    return jacfun
Example #22
 def flatten(self, value, covector=False):
     if self.shape:
         return np.concatenate(
             [s.flatten(value[k], covector)
              for k, s in sorted(iteritems(self.shape))])
     else:
         return np.zeros((0,))
Example #23
def TrackVelocity(x, k, vmax, acmax, Ta):
    ''' compute the velocity at each point along the track (given
    already-computed curvatures) assuming a certain acceleration profile '''
    v = np.minimum(np.abs(acmax / k)**0.5, vmax)

    # also compute arc distance between successive points in x given curvature
    # k; for now we'll just use the linear distance though as it's close enough
    s = np.abs(np.concatenate([x[1:] - x[:-1], x[:1] - x[-1:]]))

    va = 0
    T = 0
    vout = []

    # first pass is just to get the initial velocity
    # let's assume it's zero
    # for i in range(1, len(k)):
    #     va = va + (v[i] - va) / Ta

    for i in range(0, len(k)):
        a = (v[i] - va) / Ta  # acceleration
        dt = s[i] / (va + a/2)  # time to reach next waypoint
        va = np.minimum(va + dt * (v[i] - va) / Ta, v[i])
        T += dt
        vout.append(va)
    return np.array(vout), T
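A self-contained usage sketch on a constant-curvature (circular) track, with made-up vmax/acmax/Ta values:

import numpy as np

t = np.linspace(0, 2 * np.pi, 100, endpoint=False)
x = 5 * (np.cos(t) + 1j * np.sin(t))              # radius-5 circular track
v, T = TrackVelocity(x, k=np.full(100, 0.2), vmax=10.0, acmax=2.0, Ta=3.0)
# v approaches min(sqrt(acmax / 0.2), vmax) as the lag Ta is overcome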
Example #24
	def predict(self, x):
		if self.prob_func_ == "sigmoid":
			prob = (1.0 / (1.0 + np.exp(-np.dot(x, self.coef_) - self.intercept_)))[:,np.newaxis]
			prob = np.concatenate((1.0-prob, prob), axis=1)
		else: # self.prob_func_ == "softmax"
			prob = np.exp(np.dot(x, self.coef_.T) + self.intercept_)
			prob /= np.sum(prob, axis=1)[:,np.newaxis]
		return np.array([self.classes_[i] for i in np.argmax(prob, axis=1)])
Example #25
def build_toy_dataset(D=1, n_data=20, noise_std=0.1):
    rs = npr.RandomState(0)
    inputs  = np.concatenate([np.linspace(0, 3, num=n_data // 2),
                              np.linspace(6, 8, num=n_data // 2)])
    targets = (np.cos(inputs) + rs.randn(n_data) * noise_std) / 2.0
    inputs = (inputs - 4.0) / 2.0
    inputs  = inputs.reshape((len(inputs), D))
    return inputs, targets
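A shape check for the toy dataset (a usage sketch):

inputs, targets = build_toy_dataset(D=1, n_data=20)
assert inputs.shape == (20, 1) and targets.shape == (20,)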
Example #26
 def log_marginal_likelihood(params, data):
     cluster_lls = []
     for log_proportion, mean, chol in zip(*unpack_params(params)):
         cov = np.dot(chol.T, chol) + 0.000001 * np.eye(D)
         cluster_log_likelihood = log_proportion + mvn.logpdf(data, mean, cov)
         cluster_lls.append(np.expand_dims(cluster_log_likelihood, axis=0))
     cluster_lls = np.concatenate(cluster_lls, axis=0)
     return np.sum(logsumexp(cluster_lls, axis=0))
 def sample(var_mixture_params, num_samples, rs):
     """Sample locations aren't a continuous function of parameters
     due to multinomial sampling."""
     log_weights, var_params = unpack_mixture_params(var_mixture_params)
     samples = np.concatenate([component_sample(params_k, num_samples, rs)[:, np.newaxis, :]
                          for params_k in var_params], axis=1)
     ixs = rs.choice(k, size=num_samples, p=np.exp(log_weights))
     return np.array([samples[i, ix, :] for i, ix in enumerate(ixs)])
Example #28
    def plot_single_gp(ax, x0, y0, pred_mean, pred_cov, plot_xs):
        ax.cla()
        marg_std = np.sqrt(np.diag(pred_cov))
        if n_samples > 1:
            ax.plot(plot_xs, pred_mean, 'b')
            ax.fill(np.concatenate([plot_xs, plot_xs[::-1]]),
                    np.concatenate([pred_mean - 1.96 * marg_std,
                                    (pred_mean + 1.96 * marg_std)[::-1]]),
                    alpha=.15, fc='Blue', ec='None')

        # Show samples from posterior.
        rs = npr.RandomState(0)
        sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov*(random), size=n_samples)
        ax.plot(plot_xs, sampled_funcs.T)
        ax.plot(x0, y0, 'ro')
        ax.set_xticks([])
        ax.set_yticks([])
Example #29
    def plot_gp(ax, X, y, pred_mean, pred_cov, plot_xs):
        ax.cla()
        marg_std = np.sqrt(np.diag(pred_cov))
        ax.plot(plot_xs, pred_mean, 'b')
        ax.fill(np.concatenate([plot_xs, plot_xs[::-1]]),
                np.concatenate([pred_mean - 1.96 * marg_std,
                               (pred_mean + 1.96 * marg_std)[::-1]]),
                alpha=.15, fc='Blue', ec='None')

        # Show samples from posterior.
        rs = npr.RandomState(0)
        sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10)
        ax.plot(plot_xs, sampled_funcs.T)
        ax.plot(X, y, 'kx')
        ax.set_ylim([-1.5, 1.5])
        ax.set_xticks([])
        ax.set_yticks([])
Example #30
    def x_with_bias(x):
        """Add a row of vector which value=1 for the bias to the input X

               x.shape=(batch_size, input_vector_length) 
            => x_with_bias(x).shape=(batch_size, input_vector_length + 1)
        """
        batch_size = x.shape[0]
        return np.concatenate((x, np.ones([batch_size, 1])), axis=1)
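A usage sketch confirming the documented shapes:

import numpy as np

x = np.zeros((4, 3))
xb = x_with_bias(x)
assert xb.shape == (4, 4) and np.all(xb[:, -1] == 1)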
Example #31
def fit_weights_and_save(
        weights_file,
        ca_data_file='rs_vm_denoise_200605.npy',
        opto_silencing_data_file='vip_halo_data_for_sim.npy',
        opto_activation_data_file='vip_chrimson_data_for_sim.npy',
        constrain_wts=None,
        allow_var=True,
        fit_s02=True,
        constrain_isn=True,
        tv=False,
        l2_penalty=0.01,
        init_noise=0.1,
        init_W_from_lsq=False,
        scale_init_by=1,
        init_W_from_file=False,
        init_file=None,
        correct_Eta=False,
        init_Eta_with_s02=False,
        init_Eta12_with_dYY=False,
        use_opto_transforms=False,
        share_residuals=False,
        stimwise=False,
        simulate1=True,
        simulate2=False,
        help_constrain_isn=True,
        ignore_halo_vip=False,
        verbose=True,
        free_amplitude=False,
        norm_opto_transforms=False,
        zero_extra_weights=None,
        no_halo_res=False,
        l23_as_l4=False):

    nsize, ncontrast = 6, 6

    npfile = np.load(ca_data_file, allow_pickle=True)[()]  # 0-d object array wrapping {'rs': rs, 'rs_denoise': rs_denoise}
    rs = npfile['rs']
    #rs_denoise = npfile['rs_denoise']

    nsize, ncontrast, ndir = 6, 6, 8
    #ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    ori_dirs = [[0, 1, 2, 3, 4, 5, 6, 7]]
    nT = len(ori_dirs)
    nS = len(rs[0])

    def sum_to_1(r):
        R = r.reshape((r.shape[0], -1))
        #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        R = R / np.nansum(R, axis=1)[:, np.newaxis]  # changed 8/28
        return R

    def norm_to_mean(r):
        R = r.reshape((r.shape[0], -1))
        R = R / np.nanmean(R[:, ~np.isnan(R.sum(0))], axis=1)[:, np.newaxis]
        return R

    Rs = [[None, None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]

    for iR, r in enumerate(rs):  #rs_denoise):
        #print(iR)
        for ialign in range(nS):
            #Rs[iR][ialign] = r[ialign][:,:nsize,:]
            #sm = np.nanmean(np.nansum(np.nansum(Rs[iR][ialign],1),1))
            #Rs[iR][ialign] = Rs[iR][ialign]/sm
            #print('frac isnan Rs %d,%d: %f'%(iR,ialign,np.isnan(r[ialign]).mean()))
            Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :])
    #         Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir)))

    kernel = np.ones((1, 2, 2))
    kernel = kernel / kernel.sum()

    for iR, r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                #print('this Rs shape: '+str(Rs[iR][ialign].shape))
                #print('this Rs reshaped shape: '+str(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]].shape))
                #print('this Rs max percent nan: '+str(np.isnan(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]]).mean(-1).max()))
                Rso[iR][ialign][iori] = np.nanmean(
                    Rs[iR][ialign].reshape(
                        (-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]],
                    -1)
                Rso[iR][ialign][iori][:, :, 0] = np.nanmean(
                    Rso[iR][ialign][iori][:, :, 0],
                    1)[:, np.newaxis]  # average 0 contrast values
                #print('frac isnan pre-conv Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean()))
                Rso[iR][ialign][iori][:, 1:, 1:] = ssi.convolve(
                    Rso[iR][ialign][iori], kernel, 'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(
                    Rso[iR][ialign][iori].shape[0], -1)
                #print('frac isnan Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean()))
                #print('sum of Rso isnan: '+str(np.isnan(Rso[iR][ialign][iori]).sum(1)))
                #Rso[iR][ialign][iori] = Rso[iR][ialign][iori]/np.nanmean(Rso[iR][ialign][iori],-1)[:,np.newaxis]

    def set_bound(bd, code, val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val

    nN = 36
    nS = 2
    nP = 2
    nT = 1
    nQ = 4

    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained

    Wmx_bounds = 3 * np.ones((nP, nQ), dtype=int)
    Wmx_bounds[0, :] = 2  # L4 PCs are excitatory
    Wmx_bounds[0, 1] = 0  # SSTs don't receive L4 input

    if allow_var:
        Wsx_bounds = 3 * np.ones(
            Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
        Wsx_bounds[0, 1] = 0
    else:
        Wsx_bounds = np.zeros(
            Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)

    Wmy_bounds = 3 * np.ones((nQ, nQ), dtype=int)
    Wmy_bounds[0, :] = 2  # PCs are excitatory
    Wmy_bounds[1:, :] = -2  # all the cell types except PCs are inhibitory
    Wmy_bounds[1, 1] = 0  # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[2, 0] = 0  # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if zero_extra_weights is not None:
        Wmx_bounds[zero_extra_weights[0]] = 0
        Wmy_bounds[zero_extra_weights[1]] = 0

    if allow_var:
        Wsy_bounds = 3 * np.ones(
            Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)
        Wsy_bounds[1, 1] = 0
        Wsy_bounds[3, 1] = 0
        Wsy_bounds[2, 0] = 0
    else:
        Wsy_bounds = np.zeros(
            Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)

    if constrain_wts is not None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0], wt[1]] = 0
            Wsy_bounds[wt[0], wt[1]] = 0

    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS * nT)],
                             axis=0)[np.newaxis, :]
        tiled = np.concatenate([row for irow in range(nN)], axis=0)
        return tiled
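    # e.g. with nS * nT == 2 and nN == 36, tile_nS_nT_nN(np.array([2, 1]))
    # returns a (36, 4) array in which every row is [2, 1, 2, 1]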

    def set_bounds_by_code(lb, ub, bdlist):
        set_bound(lb, [bd == 0 for bd in bdlist], val=0)
        set_bound(ub, [bd == 0 for bd in bdlist], val=0)

        set_bound(lb, [bd == 2 for bd in bdlist], val=0)

        set_bound(ub, [bd == -2 for bd in bdlist], val=0)

        set_bound(lb, [bd == 1 for bd in bdlist], val=1)
        set_bound(ub, [bd == 1 for bd in bdlist], val=1)

        set_bound(lb, [bd == 1.5 for bd in bdlist], val=0)
        set_bound(ub, [bd == 1.5 for bd in bdlist], val=1)

        set_bound(lb, [bd == -1 for bd in bdlist], val=-1)
        set_bound(ub, [bd == -1 for bd in bdlist], val=-1)

    if fit_s02:
        s02_bounds = 2 * np.ones(
            (nQ, ))  # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ, ))

    k_bounds = 1.5 * np.ones((nQ * (nS - 1), ))

    #k_bounds[1] = 0 # temporary: spatial kernel constrained to 0 for SST
    #k_bounds[2] = 0 # temporary: spatial kernel constrained to 0 for VIP

    kappa_bounds = np.ones((1, ))
    # kappa_bounds = 2*np.ones((1,))

    T_bounds = 1.5 * np.ones((nQ * (nT - 1), ))

    X_bounds = tile_nS_nT_nN(np.array([2, 1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)

    Xp_bounds = tile_nS_nT_nN(np.array([3, 1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)

    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))

    Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))

    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))

    h1_bounds = -2 * np.ones((1, ))

    h2_bounds = 2 * np.ones((1, ))

    bl_bounds = 3 * np.ones((nQ, ))

    if free_amplitude:
        amp_bounds = 2 * np.ones((nT * nS * nQ, ))
    else:
        amp_bounds = 1 * np.ones((nT * nS * nQ, ))

    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ),
               (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
               (1, ), (1, ), (nQ, ), (nQ * nS * nT, )]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ), (nN, nT * nS * nP), (nN, nT * nS * nP)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))
    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   h1, h2
    #XX,            XXp,          Eta,          Xi

    #bdlist = [Wmx_bounds,Wmy_bounds,Wsx_bounds,Wsy_bounds,s02_bounds,k_bounds,kappa_bounds,T_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h1_bounds,h2_bounds]
    bd1list = [
        Wmx_bounds, Wmy_bounds, Wsx_bounds, Wsy_bounds, s02_bounds, k_bounds,
        kappa_bounds, T_bounds, h1_bounds, h2_bounds, bl_bounds, amp_bounds
    ]
    bd2list = [X_bounds, Xp_bounds, Eta_bounds, Xi_bounds, X_bounds, X_bounds]

    lb1, ub1 = [[sgn * np.inf * np.ones(shp) for shp in shapes1]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb1, ub1, bd1list)
    lb2, ub2 = [[sgn * np.inf * np.ones(shp) for shp in shapes2]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb2, ub2, bd2list)

    #set_bound(lb,[bd==0 for bd in bdlist],val=0)
    #set_bound(ub,[bd==0 for bd in bdlist],val=0)
    #
    #set_bound(lb,[bd==2 for bd in bdlist],val=0)
    #
    #set_bound(ub,[bd==-2 for bd in bdlist],val=0)
    #
    #set_bound(lb,[bd==1 for bd in bdlist],val=1)
    #set_bound(ub,[bd==1 for bd in bdlist],val=1)
    #
    #set_bound(lb,[bd==1.5 for bd in bdlist],val=0)
    #set_bound(ub,[bd==1.5 for bd in bdlist],val=1)
    #
    #set_bound(lb,[bd==-1 for bd in bdlist],val=-1)
    #set_bound(ub,[bd==-1 for bd in bdlist],val=-1)

    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0

    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a, b) for a, b in zip(lb1, ub1)]
    bounds2 = [(a, b) for a, b in zip(lb2, ub2)]

    nS = 2
    #print('nT: '+str(nT))
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    mx = [None for iS in range(nS)]
    for iS in range(nS):
        mx[iS] = np.zeros((ncelltypes, ))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0], 0)
            mx[iS][icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [
                np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] /
                mx[iS][icelltype] for icelltype in range(1, ncelltypes)
            ]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                rss = Rso[icelltype][iS][iT].copy(
                )  #/mx[iS][icelltype] #.reshape(Rs[icelltype][ialign].shape[0],-1)
                #print('sum of isnan: '+str(np.isnan(rss).sum(1)))
                #rss = Rso[icelltype][iS][iT].copy() #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1) == 0]
                #         print(rss.max())
                #         rss[rss<0] = 0
                #         rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [
                        (s[idim], v[idim]) for idim in range(ndims)
                    ]
    #                 print('yep on Y')
    #                 print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
                    print('nope on Y')
                    #print('shape of rss: '+str(rss.shape))
                    #print('mean of rss: '+str(np.mean(np.isnan(rss))))
                    #print('min of this rs: '+str(np.min(np.sum(rs[icelltype][iS][iT],axis=1))))
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            #         x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            #x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]#/mx[iS][icelltype]
            x = np.nanmean(Rso[icelltype][iS][iT],
                           0)[:, np.newaxis] / mx[iS][icelltype]
            #         opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            #         Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            #rss = Rso[icelltype][iS][iT].copy()/mx[iS][icelltype]
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            #         try:
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0], )))
                                   for idim in range(ndims)]
    #         except:
    #             print('nope on X')
    #             print(np.mean(np.isnan(rss)))
    #             print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))
    nN, nP = Xhat[0][0].shape
    #print('nP: '+str(nP))
    nQ = Yhat[0][0].shape[1]

    def compute_f_(Eta, Xi, s02):
        return sim_utils.f_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)]))

    def compute_fprime_m_(Eta, Xi, s02):
        return sim_utils.fprime_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02
                                         for ipixel in range(nS * nT)])) * Xi

    def compute_fprime_s_(Eta, Xi, s02):
        s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)

    def sorted_r_eigs(w):
        drW, prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds], prW[:, srtinds]

    #         0.Wmx,  1.Wmy,  2.Wsx,  3.Wsy,  4.s02,5.K,  6.kappa,7.T,8.XX,        9.XXp,        10.Eta,       11.Xi,   12.h1,  13.h2

    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ),
               (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
               (1, ), (1, ), (nQ, ), (nT * nS * nQ, )]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ), (nN, nT * nS * nP), (nN, nT * nS * nP)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))

    import calnet.fitting_spatial_feature
    import sim_utils

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)

    opto_dict = np.load(opto_silencing_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / np.nanmax(Yhat_opto[0::2], 0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]

    Xhat_opto = opto_dict['Xhat_opto']
    Xhat_opto = np.nanmean(np.reshape(Xhat_opto, (nN, 2, nS, 2, nP)),
                           3).reshape((nN * 2, -1))
    Xhat_opto[0::12] = np.nanmean(Xhat_opto[0::12], axis=0)[np.newaxis]
    Xhat_opto[1::12] = np.nanmean(Xhat_opto[1::12], axis=0)[np.newaxis]
    Xhat_opto = Xhat_opto / np.nanmax(Xhat_opto[0::2], 0)[np.newaxis, :]

    YYhat_halo = Yhat_opto.reshape((nN, 2, -1))
    opto_transform1 = calnet.utils.fit_opto_transform(
        YYhat_halo, norm01=norm_opto_transforms)

    if l23_as_l4:
        Xhat_opto[:, 0::2] = Yhat_opto[:, 0::4]
    Xhat_halo = Xhat_opto.reshape((nN, 2, -1))
    opto_transform1x = calnet.utils.fit_opto_transform(
        Xhat_halo, norm01=norm_opto_transforms)

    if no_halo_res:
        opto_transform1.res[:, [0, 2, 3, 4, 6, 7]] = 0
        opto_transform1x.res[:, :] = 0

    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)
    dXX1 = opto_transform1x.transform(XXhat) - opto_transform1x.preprocess(
        XXhat)
    print('delta bias: %f' % dXX1[:, 1].mean())

    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr, to_overwrite, n):
        arr[:, to_overwrite] = arr[:, int(to_overwrite + n)]
        return arr

    for to_overwrite in [1, 2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [2]:
        n = -2
        dXX1,opto_transform1x.slope,opto_transform1x.intercept,opto_transform1x.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dXX1,opto_transform1x.slope,opto_transform1x.intercept,opto_transform1x.res]]

    if ignore_halo_vip:
        dYY1[:, 2::nQ] = np.nan

    #for to_overwrite in [1,2]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+4]
    #for to_overwrite in [7]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite-4]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    #for to_overwrite in [1,2,5,6]: # overwrite sst and vip with off-centered values
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+8]
    #for to_overwrite in [11,15]:
    #    dYY1[:,to_overwrite] = np.nan #dYY1[:,to_overwrite-8]

    opto_dict = np.load(opto_activation_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN, 2, -1))
    opto_transform2 = calnet.utils.fit_opto_transform(
        YYhat_chrimson, norm01=norm_opto_transforms)

    Xhat_opto = np.nan * np.ones((Yhat_opto.shape[0], nP * nS * nT))
    Xhat_opto[:, 1::2] = 1
    if l23_as_l4:
        Xhat_opto[:, 0::2] = Yhat_opto[:, 0::4]

    Xhat_chrimson = Xhat_opto.reshape((nN, 2, -1))
    opto_transform2x = calnet.utils.fit_opto_transform(
        Xhat_chrimson, norm01=norm_opto_transforms)

    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)
    dXX2 = opto_transform2x.transform(XXhat) - opto_transform2x.preprocess(
        XXhat)

    #YYhat_chrimson_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_chrimson)
    #dYY2 = YYhat_chrimson_sim[:,1,:] - YYhat_chrimson_sim[:,0,:]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    #print('dYY1 mean: %03f'%np.nanmean(np.abs(dYY1)))
    #print('dYY2 mean: %03f'%np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1, dYY2), axis=0)
    dXX = np.concatenate((dXX1, dXX2), axis=0)

    #titles = ['VIP silencing','VIP activation']
    #for itype in [0,1,2,3]:
    #    plt.figure(figsize=(5,2.5))
    #    for iyy,dyy in enumerate([dYY1,dYY2]):
    #        plt.subplot(1,2,iyy+1)
    #        if np.sum(np.isnan(dyy[:,itype]))==0:
    #            sca.scatter_size_contrast(YYhat[:,itype],YYhat[:,itype]+dyy[:,itype],nsize=6,ncontrast=6)#,mn=0)
    #        plt.title(titles[iyy])
    #        plt.xlabel('cell type %d event rate, \n light off'%itype)
    #        plt.ylabel('cell type %d event rate, \n light on'%itype)
    #        ut.erase_top_right()
    #    plt.tight_layout()
    #    ut.mkdir('figures')
    #    plt.savefig('figures/scatter_light_on_light_off_target_celltype_%d.eps'%itype)

    opto_mask = ~np.isnan(dYY)
    opto_maskX = ~np.isnan(dXX)

    #dYY[nN:][~opto_mask[nN:]] = -dYY[:nN][~opto_mask[nN:]]

    #print('mean of opto_mask: '+str(opto_mask.mean()))

    #dYY[~opto_mask] = 0
    def zero_nans(arr):
        arr[np.isnan(arr)] = 0
        return arr

    #dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #        opto_transform2.slope,opto_transform2.intercept,opto_transform2.res\
    #        = [zero_nans(x) for x in \
    #                [dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #                opto_transform2.slope,opto_transform2.intercept,opto_transform2.res]]
    dYY = zero_nans(dYY)
    dXX = zero_nans(dXX)

    # for cell types that were not measured with chrimson, fill with values inferred from halo data (this shouldn't matter, as these entries are masked by opto_mask)
    to_adjust = np.logical_or(np.isnan(opto_transform2.slope[0]),
                              np.isnan(opto_transform2.intercept[0]))

    opto_transform2.slope[:, to_adjust] = 1 / opto_transform1.slope[:, to_adjust]
    opto_transform2.intercept[:, to_adjust] = -opto_transform1.intercept[:, to_adjust] / opto_transform1.slope[:, to_adjust]
    opto_transform2.res[:, to_adjust] = -opto_transform1.res[:, to_adjust] / opto_transform1.slope[:, to_adjust]

    #np.save('/Users/dan/Documents/notebooks/mossing-PC/shared_data/calnet_data/dYY.npy',dYY)

    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_2step_spatial_feature_opto_tight_nonlinear)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)
    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 3  # 1 #
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))  #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0.  #30.0 #0.1
    wt_dict['opto'] = 1  #1e1
    wt_dict['isn'] = 0.3
    wt_dict['tv'] = 1
    spont_frac = 0.5
    pc_frac = 0.5
    wt_dict['stimsOpto'] = (1 - spont_frac) * 6 / 5 * np.ones((nN, 1))
    wt_dict['stimsOpto'][0::6] = spont_frac * 6
    wt_dict['celltypesOpto'] = (1 - pc_frac) * 4 / 3 * np.ones(
        (1, nQ * nS * nT))
    wt_dict['celltypesOpto'][0, 0::nQ] = pc_frac * 4
    wt_dict['dirOpto'] = np.array((1, 0.3))
    wt_dict['dYY'] = 10  #10
    wt_dict['dXX'] = 10  #10
    wt_dict['coupling'] = 1e-3
    wt_dict['smi'] = 0.1
    wt_dict['smi_halo'] = 30
    wt_dict['smi_chrimson'] = 0.1

    ##temporary no_opto
    wt_dict['opto'] = 0.01  #0
    wt_dict['dirOpto'] = np.array((1, 1))
    wt_dict['stimsOpto'] = np.ones((nN, 1))
    wt_dict['celltypesOpto'] = np.ones((1, nQ * nS * nT))
    wt_dict['smi'] = 0  #0.01 # 0
    wt_dict['smi_halo'] = 0  #1 # 0
    wt_dict['smi_chrimson'] = 0  #0.01 # 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 0.1
    wt_dict['X'] = 3
    #wt_dict['Eta'] = 300 # 1 #

    ## temporary opto from no_opto
    #wt_dict['opto'] = 0.01
    #wt_dict['tv'] = 0.3#0.1

    np.save(
        'XXYYhat.npy', {
            'YYhat': YYhat,
            'XXhat': XXhat,
            'rs': rs,
            'Rs': Rs,
            'Rso': Rso,
            'Ypc_list': Ypc_list,
            'Xpc_list': Xpc_list
        })
    Eta0 = invert_f_mt(YYhat)

    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   h1, h2
    #XX,            XXp,          Eta,          Xi,     XX1,    XX2

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10 / dt))  #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    is_neg = np.array([b[1] for b in bounds1]) == 0
    counter = 0
    negatize = [np.zeros(shp, dtype='bool') for shp in shapes1]
    #print(shapes1)
    for ishp, shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter + nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            #print((ihyper,itry))
            W10list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes1
            ]
            W20list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes2
            ]
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('len(W10list) : '+str(len(W10list)))
            counter = 0
            for ishp, shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            W10list[4] = np.ones(shapes1[4])  # s02
            W10list[5] = np.ones(shapes1[5])  # K
            W10list[6] = np.ones(shapes1[6])  # kappa
            W10list[7] = np.ones(shapes1[7])  # T
            W20list[0] = np.concatenate(Xhat, axis=1)  #XX
            W20list[1] = np.zeros_like(W20list[1])  #XXp
            W20list[2] = Eta0.copy()  #np.zeros(shapes[10]) #Eta
            W20list[3] = np.zeros(shapes2[3])  #Xi
            W20list[4] = np.concatenate(Xhat, axis=1)  #XX
            W20list[5] = np.concatenate(Xhat, axis=1)  #XX
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
            if init_W_from_lsq:
                W10list[0], W10list[1] = initialize_W(Xhat,
                                                      Yhat,
                                                      scale_by=scale_init_by)
                for ivar in range(0, 2):
                    W10list[ivar] = W10list[ivar] + init_noise * np.random.randn(*W10list[ivar].shape)
            if constrain_isn:
                W10list[1][0, 0] = 3
                if help_constrain_isn:
                    W10list[1][0, 3] = 5
                    W10list[1][3, 0] = -5
                    W10list[1][3, 3] = -5
                else:
                    W10list[1][0, 1:4] = 5
                    W10list[1][1:4, 0] = -5

            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]

                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
                #XX,XXp,Eta,Xi,XX1,XX2 = parse_W2(W2)
                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,XX1,XX2,h1,h2,bl,amp = parse_W1(W1)
                if len(npyfile['as_list']) == 18:
                    W10list = [
                        npyfile['as_list'][ivar]
                        for ivar in [0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 16, 17]
                    ]
                    W20list = [
                        npyfile['as_list'][ivar]
                        for ivar in [8, 9, 10, 11, 12, 13]
                    ]
                elif len(npyfile['as_list']) == 16:
                    W10list = [
                        npyfile['as_list'][ivar]
                        for ivar in [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
                    ]
                    W20list = [
                        npyfile['as_list'][ivar]
                        for ivar in [8, 9, 10, 11, 8, 8]
                    ]
                if W20list[0].size == nN * nS * 2 * nP:
                    #assert(True==False)
                    W10list[7] = np.array(())
                    W10list[1][1, 0] = W10list[1][1, 0]
                    W20list[0] = np.nanmean(
                        W20list[0].reshape((nN, nS, 2, nP)), 2).flatten()  #XX
                    W20list[1] = np.nanmean(
                        W20list[1].reshape((nN, nS, 2, nP)), 2).flatten()  #XXp
                    W20list[2] = np.nanmean(
                        W20list[2].reshape((nN, nS, 2, nQ)), 2).flatten()  #Eta
                    W20list[3] = np.nanmean(
                        W20list[3].reshape((nN, nS, 2, nQ)), 2).flatten()  #Xi
                    W20list[4] = np.nanmean(
                        W20list[4].reshape((nN, nS, 2, nP)), 2).flatten()  #XX1
                    W20list[5] = np.nanmean(
                        W20list[5].reshape((nN, nS, 2, nP)), 2).flatten()  #XX2
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                if len(W10list) < len(shapes1):
                    #assert(True==False)
                    W10list = W10list + [
                        np.array(1),
                        np.zeros((nQ, )),
                        np.zeros((nT * nS * nQ, ))
                    ]  # add h2, bl, amp
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat, s02, nS=nS, nT=nT)
                    W20list[2] = Eta0.copy()
                #if init_Eta12_with_dYY:
                #    Eta0 = W20list[2].copy().reshape((nN,nQ*nS*nT))
                #    Xi0 = W20list[3].copy().reshape((nN,nQ*nS*nT))
                #    s020 = W10list[4].copy()
                #    YY0s = compute_f_(Eta0,Xi0,s020)
                #titles = ['VIP silencing','VIP activation']
                #for itype in [0,1,2,3]:
                #    plt.figure(figsize=(5,2.5))
                #    for iyy,yy in enumerate([YY10s,YY20s]):
                #        plt.subplot(1,2,iyy+1)
                #        if np.sum(np.isnan(yy[:,itype]))==0:
                #            sca.scatter_size_contrast(YY0s[:,itype],yy[:,itype],nsize=6,ncontrast=6)#,mn=0)
                #        plt.title(titles[iyy])
                #        plt.xlabel('cell type %d event rate, \n light off'%itype)
                #        plt.ylabel('cell type %d event rate, \n light on'%itype)
                #        ut.erase_top_right()
                #    plt.tight_layout()
                #    ut.mkdir('figures')
                #    plt.savefig('figures/scatter_light_on_light_off_init_celltype_%d.eps'%itype)
                for ivar in [0, 1, 4, 5]:  # Wmx, Wmy, s02, k
                    print(init_noise)
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)

            #print('size of bounds1: '+str(np.sum([np.size(x) for x in bd1list])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            W1t[ihyper][itry], W2t[ihyper][itry], loss[ihyper][
                itry], gr, hess, result = calnet.fitting_2step_opto_layers.fit_W_sim(
                    Xhat,
                    Xpc_list,
                    Yhat,
                    Ypc_list,
                    pop_rate_fn=sim_utils.f_miller_troyer,
                    pop_deriv_fn=sim_utils.fprime_miller_troyer,
                    neuron_rate_fn=sim_utils.evaluate_f_mt,
                    W10list=W10list.copy(),
                    W20list=W20list.copy(),
                    bounds1=bounds1,
                    bounds2=bounds2,
                    niter=niter,
                    wt_dict=wt_dict,
                    l2_penalty=l2_penalty,
                    compute_hessian=False,
                    dt=dt,
                    perturbation_size=perturbation_size,
                    dYY=dYY,
                    dXX=dXX,
                    constrain_isn=constrain_isn,
                    tv=tv,
                    opto_mask=opto_mask,
                    opto_maskX=opto_maskX,
                    use_opto_transforms=use_opto_transforms,
                    opto_transform1=opto_transform1,
                    opto_transform1x=opto_transform1x,
                    opto_transform2=opto_transform2,
                    opto_transform2x=opto_transform2x,
                    share_residuals=share_residuals,
                    stimwise=stimwise,
                    simulate1=simulate1,
                    simulate2=simulate2,
                    verbose=verbose)

    #def parse_W(W):
    #    Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = W
    #    return Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2
    def parse_W1(W):
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = W
        return Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp

    def parse_W2(W):
        XX, XXp, Eta, Xi, XX1, XX2 = W
        return XX, XXp, Eta, Xi, XX1, XX2

    itry = 0
    Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = parse_W1(W1t[0][0])
    XX, XXp, Eta, Xi, XX1, XX2 = parse_W2(W2t[0][0])

    labels1 = [
        'Wmx', 'Wmy', 'Wsx', 'Wsy', 's02', 'K', 'kappa', 'T', 'h1', 'h2', 'bl',
        'amp'
    ]
    labels2 = ['XX', 'XXp', 'Eta', 'Xi', 'XX1', 'XX2']
    Wstar_dict = {}
    for i, label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i, label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    Wstar_dict['as_list'] = [
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, XX1, XX2, h1,
        h2, bl, amp
    ]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
Ejemplo n.º 32
0
 def make_L_col(i):
     nelems = d - i - 1
     col = np.concatenate(
         (np.zeros(i + 1), icf[constructL.Lparamidx:(constructL.Lparamidx + nelems)]))
     constructL.Lparamidx += nelems
     return col
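# For context, make_L_col is normally a closure inside a constructL-style
# helper that unpacks a flat parameter vector icf into a strictly
# lower-triangular matrix; it is assumed here that the first d entries of icf
# hold the diagonal, which is handled elsewhere. A minimal runnable sketch:
import autograd.numpy as np

def constructL(d, icf):
    constructL.Lparamidx = d  # start reading after the diagonal block

    def make_L_col(i):
        # column i: i + 1 leading zeros, then d - i - 1 packed entries
        nelems = d - i - 1
        col = np.concatenate(
            (np.zeros(i + 1), icf[constructL.Lparamidx:(constructL.Lparamidx + nelems)]))
        constructL.Lparamidx += nelems
        return col

    return np.column_stack([make_L_col(i) for i in range(d)])

# e.g. d = 3 uses 3 strictly-lower entries, so icf needs 3 + 3 = 6 elements:
# constructL(3, np.arange(6.)) -> [[0, 0, 0], [3, 0, 0], [4, 5, 0]]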
Ejemplo n.º 33
0
    def wake(self, wake_data, it):

        ddc = self.reg
        self.wake_data = wake_data.copy()

        gs = []

        nl_obs = self.nlayer - 1

        mean_name_higher = "mx%d_x%d" % (self.nlayer - 2, nl_obs)
        fun_name = "x%d->x%d" % (nl_obs, self.nlayer - 2)
        self.wake_data[mean_name_higher] = ddc.predict(self.wake_data,
                                                       fun_name)

        if self.layer_plastic[-1]:

            A = self.reg.Ws["A"]
            B = self.reg.Ws["B"]

            z_mean = self.wake_data["mx%d_x%d" % (nl_obs - 1, nl_obs)]
            x_suff = self.model.dists[-1].suff(self.wake_data["x%d" % nl_obs])
            x = np.einsum("ij,ik->ijk", z_mean,
                          x_suff).reshape(z_mean.shape[0], -1)
            y = z_mean

            dnatsuff = x.dot(A)
            dnorm = y.dot(B)
            g = dnatsuff - dnorm

            self.wake_data["dnatsuff%d" % nl_obs] = dnatsuff
            self.wake_data["dnorm%d" % nl_obs] = dnorm
            self.wake_data["dlogp%d" % (nl_obs)] = g
            gs.insert(0, g.mean(0))

        else:
            gs.insert(0, np.zeros_like(self.model.dists[-1].ps))

        if self.nlayer > 1:

            for i in range(self.nlayer - 2, 0, -1):

                mean_name_lower = mean_name_higher
                mean_name_higher = "mx%d_x%d" % (i - 1, nl_obs)
                fun_name = "mx%d_x%d->x%d" % (i, nl_obs, i - 1)
                self.wake_data[mean_name_higher] = ddc.predict(
                    self.wake_data, fun_name)

                if self.layer_plastic[i]:

                    grad_name = "x%d->dnatsuff%d" % (i, i)
                    dnatsuff = self.approx_E(mean_name_lower, grad_name)
                    self.wake_data["dnatsuff%d" % i] = dnatsuff
                    #dnatsuff = self.model.dists[i].dnatsuff(self.wake_data["x%d"%(i-1)], self.wake_data["x%d"%(i)])

                    grad_name = "x%d->dnorm%d" % (i - 1, i)
                    dnorm = self.approx_E(mean_name_higher, grad_name)
                    self.wake_data["dnorm%d" % i] = dnorm
                    #dnorm    = self.model.dists[i].dnorm(self.wake_data["x%d"%(i-1)])

                    g = (dnatsuff - dnorm)
                    #g2 = (self.model.dists[i].dlogp(self.wake_data["x%d"%(i-1)], self.wake_data["x%d"%(i)])).mean(0)
                    #assert np.allclose(g, g2)
                    self.wake_data["dlogp%d" % i] = g
                    gs.insert(0, g.mean(0))

                else:
                    gs.insert(0, np.zeros_like(self.model.dists[i].ps))

            if self.layer_plastic[0]:

                grad_name = "x0->fsuff0"
                suff = self.approx_E(mean_name_higher, grad_name)
                #suff = self.model.dists[0].suff(self.wake_data["x0"])

                dnat = self.model.dists[0].dnat()
                dnatsuff = self.model.dists[0].dnatsuff_from_dnatsuff(
                    dnat, suff)
                self.wake_data["dnatsuff0"] = dnatsuff

                dnorm = self.model.dists[0].dnorm(n=dnatsuff.shape[0])
                self.wake_data["dnorm0"] = dnorm
                g = (dnatsuff - dnorm)
                #g    = self.model.dists[0].dlogp(self.wake_data["x0"]).mean(0)
                #assert np.allclose(g_exp,g)
                self.wake_data["dlogp0"] = g
                gs.insert(0, g.mean(0))

            else:
                gs.insert(0, np.zeros_like(self.model.dists[0].ps))

        gs = np.concatenate(gs)

        self.gradient_step(gs, it)
Ejemplo n.º 34
0
def test_slogdet_3d():
    fun = lambda x: np.sum(np.linalg.slogdet(x)[1])
    mat = np.concatenate([(rand_psd(5) + 5 * np.eye(5))[None, ...]
                          for _ in range(3)])
    check_grads(fun)(mat)
Ejemplo n.º 35
0
def test_cholesky_broadcast():
    fun = lambda A: np.linalg.cholesky(A)
    A = np.concatenate([rand_psd(6)[None, :, :] for i in range(3)], axis=0)
    check_symmetric_matrix_grads(fun)(A)
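# Both tests above rely on a rand_psd helper; a common construction for a
# random positive semi-definite matrix is A @ A.T (a sketch -- the exact
# helper used by the test suite may differ):
import autograd.numpy as np
import autograd.numpy.random as npr

def rand_psd(D):
    mat = npr.randn(D, D)
    return np.dot(mat, mat.T)  # symmetric PSD for any real square mat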
Ejemplo n.º 36
0
def compare_2d3d(func1, func2, **kwargs):
    view = [20, -50]
    if 'view' in kwargs:
        view = kwargs['view']

    # construct figure
    fig = plt.figure(figsize=(12, 4))

    # remove whitespace from figure
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)  # remove whitespace
    fig.subplots_adjust(wspace=0.01, hspace=0.01)

    # create subplot with 3 panels, plot input function in center plot
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 2, 4])

    ### draw 2d version ###
    ax1 = plt.subplot(gs[1])
    grad = compute_grad(func1)

    # generate a range of values over which to plot input function, and derivatives
    w_plot = np.linspace(-3, 3, 200)  # input range for original function
    g_plot = func1(w_plot)
    g_range = max(g_plot) - min(g_plot)  # used for cleaning up final plot
    ggap = g_range * 0.2
    w_vals = np.linspace(-2.5, 2.5, 200)

    # grab the next input/output tangency pair, the center of the next approximation(s)
    w_val = float(0)
    g_val = func1(w_val)

    # plot original function
    ax1.plot(w_plot, g_plot, color='k', zorder=1, linewidth=2)

    # plot axis
    ax1.plot(w_plot, g_plot * 0, color='k', zorder=1, linewidth=1)
    # plot the input/output tangency point
    ax1.scatter(w_val,
                g_val,
                s=80,
                c='lime',
                edgecolor='k',
                linewidth=2,
                zorder=3)  # plot point of tangency

    #### plot first order approximation ####
    # plug input into the first derivative
    g_grad_val = grad(w_val)

    # determine width to plot the approximation -- so its length == width
    width = 4
    div = float(1 + g_grad_val**2)
    w1 = w_val - math.sqrt(width / div)
    w2 = w_val + math.sqrt(width / div)

    # compute first order approximation
    wrange = np.linspace(w1, w2, 100)
    h = g_val + g_grad_val * (wrange - w_val)

    # plot the first order approximation
    ax1.plot(wrange, h, color='lime', alpha=0.5, linewidth=3,
             zorder=2)  # plot approx

    #### clean up panel ####
    # fix viewing limits on panel
    v = 5
    ax1.set_xlim([-v, v])
    ax1.set_ylim([-1 - 0.3, v - 0.3])

    # label axes
    ax1.set_xlabel('$w$', fontsize=12, labelpad=-60)
    ax1.set_ylabel('$g(w)$', fontsize=25, rotation=0, labelpad=50)
    ax1.grid(False)
    ax1.yaxis.set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax1.spines['left'].set_visible(False)

    ### draw 3d version ###
    ax2 = plt.subplot(gs[2], projection='3d')
    grad = compute_grad(func2)
    w_val = [float(0), float(0)]

    # define input space
    w_in = np.linspace(-2, 2, 200)
    w1_vals, w2_vals = np.meshgrid(w_in, w_in)
    w1_vals.shape = (len(w_in)**2, 1)
    w2_vals.shape = (len(w_in)**2, 1)
    w_vals = np.concatenate((w1_vals, w2_vals), axis=1).T
    g_vals = func2(w_vals)

    # evaluation points
    w_val = np.array([float(w_val[0]), float(w_val[1])])
    w_val.shape = (2, 1)
    g_val = func2(w_val)
    grad_val = grad(w_val)
    grad_val.shape = (2, 1)

    # create and evaluate tangent hyperplane
    w_tan = np.linspace(-1, 1, 200)
    w1tan_vals, w2tan_vals = np.meshgrid(w_tan, w_tan)
    w1tan_vals.shape = (len(w_tan)**2, 1)
    w2tan_vals.shape = (len(w_tan)**2, 1)
    wtan_vals = np.concatenate((w1tan_vals, w2tan_vals), axis=1).T

    #h = lambda weh: g_val +  np.dot( (weh - w_val).T,grad_val)
    h = lambda weh: (g_val + (weh[0] - w_val[0]) * grad_val[0] +
                     (weh[1] - w_val[1]) * grad_val[1])
    h_vals = h(wtan_vals + w_val)
    zmin = min(np.min(h_vals), -0.5)
    zmax = max(np.max(h_vals), +0.5)

    # vals for cost surface, reshape for plot_surface function
    w1_vals.shape = (len(w_in), len(w_in))
    w2_vals.shape = (len(w_in), len(w_in))
    g_vals.shape = (len(w_in), len(w_in))
    w1tan_vals += w_val[0]
    w2tan_vals += w_val[1]
    w1tan_vals.shape = (len(w_tan), len(w_tan))
    w2tan_vals.shape = (len(w_tan), len(w_tan))
    h_vals.shape = (len(w_tan), len(w_tan))

    ### plot function ###
    ax2.plot_surface(w1_vals,
                     w2_vals,
                     g_vals,
                     alpha=0.5,
                     color='w',
                     rstride=25,
                     cstride=25,
                     linewidth=1,
                     edgecolor='k',
                     zorder=2)

    ### plot z=0 plane ###
    ax2.plot_surface(w1_vals,
                     w2_vals,
                     g_vals * 0,
                     alpha=0.1,
                     color='w',
                     zorder=1,
                     rstride=25,
                     cstride=25,
                     linewidth=0.3,
                     edgecolor='k')

    ### plot tangent plane ###
    ax2.plot_surface(w1tan_vals,
                     w2tan_vals,
                     h_vals,
                     alpha=0.4,
                     color='lime',
                     zorder=1,
                     rstride=50,
                     cstride=50,
                     linewidth=1,
                     edgecolor='k')

    # scatter tangency
    ax2.scatter(w_val[0],
                w_val[1],
                g_val,
                s=70,
                c='lime',
                edgecolor='k',
                linewidth=2)

    ### clean up plot ###
    # plot x and y axes, and clean up
    ax2.xaxis.pane.fill = False
    ax2.yaxis.pane.fill = False
    ax2.zaxis.pane.fill = False

    ax2.xaxis.pane.set_edgecolor('white')
    ax2.yaxis.pane.set_edgecolor('white')
    ax2.zaxis.pane.set_edgecolor('white')

    # remove axes lines and tickmarks
    ax2.w_zaxis.line.set_lw(0.)
    ax2.set_zticks([])
    ax2.w_xaxis.line.set_lw(0.)
    ax2.set_xticks([])
    ax2.w_yaxis.line.set_lw(0.)
    ax2.set_yticks([])

    # set viewing angle
    ax2.view_init(20, -65)

    # set viewing limits
    y = 4
    ax2.set_xlim([-y, y])
    ax2.set_ylim([-y, y])
    ax2.set_zlim([zmin, zmax])

    # label plot
    fontsize = 12
    ax2.set_xlabel(r'$w_1$', fontsize=fontsize, labelpad=-35)
    ax2.set_ylabel(r'$w_2$', fontsize=fontsize, rotation=0, labelpad=-40)

    plt.show()
Ejemplo n.º 37
0
def dynamics_fn(t, coords):
    dcoords = autograd.grad(hamiltonian_fn)(coords)
    dqdt, dpdt = np.split(dcoords, 2)
    S = np.concatenate([dpdt, -dqdt], axis=-1)
    return S
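# dynamics_fn applies Hamilton's equations, dq/dt = dH/dp and dp/dt = -dH/dq,
# to a state vector coords = [q, p]; hamiltonian_fn comes from the surrounding
# code. A runnable sketch with an assumed harmonic oscillator H = (q^2 + p^2)/2:
import autograd
import autograd.numpy as np

def hamiltonian_fn(coords):
    q, p = np.split(coords, 2)
    return ((q ** 2 + p ** 2) / 2.0).sum()

# At (q, p) = (1, 0): dq/dt = dH/dp = 0 and dp/dt = -dH/dq = -1.
print(dynamics_fn(0.0, np.array([1.0, 0.0])))  # -> [ 0. -1.]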
Ejemplo n.º 38
0
    def animate_it_3d(self, w_hist, **kwargs):
        self.w_hist = w_hist

        ##### setup figure to plot #####
        # initialize figure
        fig = plt.figure(figsize=(8, 3))
        artist = fig

        # create subplot with 3 panels, plot input function in center plot
        gs = gridspec.GridSpec(1, 2, width_ratios=[2, 1])
        ax1 = plt.subplot(gs[0], projection='3d')
        ax2 = plt.subplot(gs[1])

        # produce color scheme
        s = np.linspace(0, 1, len(self.w_hist[:round(len(self.w_hist) / 2)]))
        s.shape = (len(s), 1)
        t = np.ones(len(self.w_hist[round(len(self.w_hist) / 2):]))
        t.shape = (len(t), 1)
        s = np.vstack((s, t))
        self.colorspec = []
        self.colorspec = np.concatenate((s, np.flipud(s)), 1)
        self.colorspec = np.concatenate((self.colorspec, np.zeros(
            (len(s), 1))), 1)

        # seed left panel plotting range
        viewmax = 3
        if 'viewmax' in kwargs:
            viewmax = kwargs['viewmax']
        r = np.linspace(-viewmax, viewmax, 200)

        # create grid from plotting range
        x1_vals, x2_vals = np.meshgrid(r, r)

        # seed left panel view
        view = [20, 50]
        if 'view' in kwargs:
            view = kwargs['view']

        # set zaxis to the left
        self.move_axis_left(ax1)

        # start animation
        num_frames = len(self.w_hist)
        print('starting animation rendering...')

        def animate(k):
            # clear panels
            ax1.cla()

            # set axis in left panel
            self.move_axis_left(ax1)

            # current color
            color = self.colorspec[k]

            # print rendering update
            if np.mod(k + 1, 25) == 0:
                print('rendering animation frame ' + str(k + 1) + ' of ' +
                      str(num_frames))
            if k == num_frames - 1:
                print('animation rendering complete!')
                time.sleep(1.5)
                clear_output()

            ###### make left panel - plot data and fit ######
            # initialize fit
            w = self.w_hist[k]

            # reshape and plot the surface, as well as where the zero-plane is
            y_fit = w[0] + w[1] * x1_vals + w[2] * x2_vals

            # plot cost surface
            ax1.plot_surface(x1_vals,
                             x2_vals,
                             y_fit,
                             alpha=0.1,
                             color=color,
                             rstride=25,
                             cstride=25,
                             linewidth=0.25,
                             edgecolor='k',
                             zorder=2)

            # scatter data
            self.scatter_pts(ax1)
            #ax1.view_init(view[0],view[1])

            # plot connector between points for visualization purposes
            if k == 0:
                w_new = self.w_hist[k]
                g_new = self.least_squares(w_new)[0]
                ax2.scatter(k,
                            g_new,
                            s=0.1,
                            color='w',
                            linewidth=2.5,
                            alpha=0,
                            zorder=1)  # plot approx

            if k > 0:
                w_old = self.w_hist[k - 1]
                w_new = self.w_hist[k]
                g_old = self.least_squares(w_old)[0]
                g_new = self.least_squares(w_new)[0]

                ax2.plot([k - 1, k], [g_old, g_new],
                         color=color,
                         linewidth=2.5,
                         alpha=1,
                         zorder=2)  # plot approx
                ax2.plot([k - 1, k], [g_old, g_new],
                         color='k',
                         linewidth=3.5,
                         alpha=1,
                         zorder=1)  # plot approx

            # set viewing limits for second panel
            ax2.axhline(y=0, color='k', zorder=0, linewidth=0.5)
            ax2.set_xlabel('iteration', fontsize=12)
            ax2.set_ylabel(r'$g(\mathbf{w})$',
                           fontsize=12,
                           rotation=0,
                           labelpad=25)
            ax2.set_xlim([-0.5, len(self.w_hist)])

            # set axis in left panel
            self.move_axis_left(ax1)

            return artist,

        anim = animation.FuncAnimation(fig,
                                       animate,
                                       frames=num_frames,
                                       interval=num_frames,
                                       blit=True)

        return (anim)
Ejemplo n.º 39
0
    def animate_it_2d(self, w_hist, **kwargs):
        self.w_hist = w_hist

        ##### setup figure to plot #####
        # initialize figure
        fig = plt.figure(figsize=(8, 3))
        artist = fig

        # create subplot with 3 panels, plot input function in center plot
        gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1])

        # produce color scheme
        s = np.linspace(0, 1, len(self.w_hist[:round(len(self.w_hist) / 2)]))
        s.shape = (len(s), 1)
        t = np.ones(len(self.w_hist[round(len(self.w_hist) / 2):]))
        t.shape = (len(t), 1)
        s = np.vstack((s, t))
        self.colorspec = []
        self.colorspec = np.concatenate((s, np.flipud(s)), 1)
        self.colorspec = np.concatenate((self.colorspec, np.zeros(
            (len(s), 1))), 1)

        # seed left panel plotting range
        xmin = copy.deepcopy(min(self.x))
        xmax = copy.deepcopy(max(self.x))
        xgap = (xmax - xmin) * 0.1
        xmin -= xgap
        xmax += xgap
        x_fit = np.linspace(xmin, xmax, 300)

        # seed right panel contour plot
        viewmax = 3
        if 'viewmax' in kwargs:
            viewmax = kwargs['viewmax']
        view = [20, 100]
        if 'view' in kwargs:
            view = kwargs['view']
        num_contours = 15
        if 'num_contours' in kwargs:
            num_contours = kwargs['num_contours']
        self.contour_plot(ax2, viewmax, num_contours)

        # start animation
        num_frames = len(self.w_hist)
        print('starting animation rendering...')

        def animate(k):
            # clear panels
            ax1.cla()

            # current color
            color = self.colorspec[k]

            # print rendering update
            if np.mod(k + 1, 25) == 0:
                print('rendering animation frame ' + str(k + 1) + ' of ' +
                      str(num_frames))
            if k == num_frames - 1:
                print('animation rendering complete!')
                time.sleep(1.5)
                clear_output()

            ###### make left panel - plot data and fit ######
            # initialize fit
            w = self.w_hist[k]
            y_fit = w[0] + x_fit * w[1]

            # scatter data
            self.scatter_pts(ax1)

            # plot fit to data
            ax1.plot(x_fit, y_fit, color=color, linewidth=3)

            ###### make right panel - plot contour and steps ######
            if k == 0:
                ax2.scatter(w[0],
                            w[1],
                            s=90,
                            facecolor=color,
                            edgecolor='k',
                            linewidth=0.5,
                            zorder=3)
            if k > 0 and k < num_frames:
                self.plot_pts_on_contour(ax2, k, color)
            if k == num_frames - 1:
                ax2.scatter(w[0],
                            w[1],
                            s=90,
                            facecolor=color,
                            edgecolor='k',
                            linewidth=0.5,
                            zorder=3)

            return artist,

        anim = animation.FuncAnimation(fig,
                                       animate,
                                       frames=num_frames,
                                       interval=num_frames,
                                       blit=True)

        return (anim)
Ejemplo n.º 40
0
 def tile_nS_nT_nN(kernel):
     row = np.concatenate([kernel for idim in range(nS * nT)],
                          axis=0)[np.newaxis, :]
     tiled = np.concatenate([row for irow in range(nN)], axis=0)
     return tiled
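# tile_nS_nT_nN repeats a 1-D kernel nS*nT times within a row and stacks nN
# identical rows; for a 1-D kernel this is equivalent to a single np.tile call
# (demo with made-up sizes):
import numpy as np

nS, nT, nN = 2, 3, 4  # illustrative sizes
kernel = np.array([1.0, 2.0])
row = np.concatenate([kernel for idim in range(nS * nT)], axis=0)[np.newaxis, :]
tiled = np.concatenate([row for irow in range(nN)], axis=0)
assert np.array_equal(tiled, np.tile(kernel, (nN, nS * nT)))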
Ejemplo n.º 41
0
 def compute_fprime_m_(Eta, Xi, s02):
     return sim_utils.fprime_miller_troyer(
         Eta, Xi**2 + np.concatenate([s02
                                      for ipixel in range(nS * nT)])) * Xi
Ejemplo n.º 42
0
    def wake(self, wake_data, it):

        ddc = self.reg
        self.wake_data = wake_data.copy()

        gs = []

        nl_obs = self.nlayer - 1

        mean_name_higher = "mx%d_x%d" % (self.nlayer - 2, nl_obs)
        fun_name = "x%d->x%d" % (nl_obs, self.nlayer - 2)
        self.wake_data[mean_name_higher] = ddc.predict(self.wake_data,
                                                       fun_name)

        if self.layer_plastic[-1]:

            grad_name = "x%d->dnat%d" % (self.nlayer - 2, nl_obs)
            dnat = self.approx_E(mean_name_higher, grad_name)
            #dnat  = self.model.dists[-1].dnat(self.wake_data["x%d"%(self.nlayer-2)])

            grad_name = "x%d->dnorm%d" % (self.nlayer - 2, nl_obs)
            dnorm = self.approx_E(mean_name_higher, grad_name)
            #dnorm  = self.model.dists[-1].dnorm(self.wake_data["x%d"%(self.nlayer-2)])

            suff = self.model.dists[-1].suff(self.wake_data["x%d" % (nl_obs)])
            g = (self.model.dists[-1].dnatsuff_from_dnatsuff(dnat, suff) -
                 dnorm)
            self.wake_data["dlogp%d" % (nl_obs)] = g
            gs.insert(0, g.mean(0))
        else:
            gs.insert(0, np.zeros_like(self.model.dists[-1].ps))

        if self.nlayer > 1:

            for i in range(self.nlayer - 2, 0, -1):

                mean_name_lower = mean_name_higher
                mean_name_higher = "mx%d_x%d" % (i - 1, nl_obs)
                fun_name = "mx%d_x%d->x%d" % (i, nl_obs, i - 1)
                self.wake_data[mean_name_higher] = ddc.predict(
                    self.wake_data, fun_name)

                if self.layer_plastic[i]:

                    grad_name = "x%d->dnatsuff%d" % (i, i)
                    dnatsuff = self.approx_E(mean_name_lower, grad_name)
                    #dnatsuff = self.model.dists[i].dnatsuff(self.wake_data["x%d"%(i-1)], self.wake_data["x%d"%(i)])

                    grad_name = "x%d->dnorm%d" % (i - 1, i)
                    dnorm = self.approx_E(mean_name_higher, grad_name)
                    #dnorm    = self.model.dists[i].dnorm(self.wake_data["x%d"%(i-1)])

                    g = (dnatsuff - dnorm)
                    #g2 = (self.model.dists[i].dlogp(self.wake_data["x%d"%(i-1)], self.wake_data["x%d"%(i)])).mean(0)
                    #assert np.allclose(g, g2)
                    self.wake_data["dlogp%d" % i] = g
                    gs.insert(0, g.mean(0))

                else:
                    gs.insert(0, np.zeros_like(self.model.dists[i].ps))

            if self.layer_plastic[0]:

                grad_name = "x0->fsuff0"
                suff = self.approx_E(mean_name_higher, grad_name)
                #suff = self.model.dists[0].suff(self.wake_data["x0"])

                dnat = self.model.dists[0].dnat()
                dnatsuff = self.model.dists[0].dnatsuff_from_dnatsuff(
                    dnat, suff)
                dnorm = self.model.dists[0].dnorm()
                g = (dnatsuff - dnorm)
                #g    = self.model.dists[0].dlogp(self.wake_data["x0"]).mean(0)
                #assert np.allclose(g_exp,g)
                self.wake_data["dlogp0"] = g
                gs.insert(0, g.mean(0))

            else:
                gs.insert(0, np.zeros_like(self.model.dists[0].ps))

        gs = np.concatenate(gs)

        self.gradient_step(gs, it)
Ejemplo n.º 43
0
    xm = np.array(x)[:, 0] / 50  # 50 pixels / meter
    track_k = TrackCurvature(xm)
    Nx = TrackNormal(xm)
    u = 1j * Nx
    np.savetxt("track_x.txt",
               np.vstack([np.real(xm), np.imag(xm)]).T.reshape(-1),
               newline=",\n")
    np.savetxt("track_u.txt",
               np.vstack([np.real(u), np.imag(u)]).T.reshape(-1),
               newline=",\n")
    np.savetxt("track_k.txt", track_k, newline=",\n")

    ye, val, stuff = OptimizeTrack(xm, 1.0, 0.05, 0.4)
    psie = RelativePsie(ye, xm)

    rx = u * ye + xm
    raceline_k = TrackCurvature(rx)

    np.savetxt("raceline_k.txt", raceline_k, newline=",\n")
    np.savetxt("raceline_ye.txt", ye, newline=",\n")
    np.savetxt("raceline_psie.txt", psie, newline=",\n")
    np.savetxt(
        "raceline_ds.txt",
        np.abs(  # distance between successive points
            np.concatenate([rx[1:] - rx[:-1], rx[:1] - rx[-1:]])),
        newline=",\n")

    print "saved track_k.txt"
    print "saved raceline_{k, ye, psie}.txt"
Ejemplo n.º 44
0
def ReverseBinaryFlip(z, j):
    return np.concatenate(reversed(ReverseBinarySplit(z, j)), -1)
Ejemplo n.º 45
0
def reconstruct_fullfield(fname,
                          theta_st=0,
                          theta_end=PI,
                          n_epochs='auto',
                          crit_conv_rate=0.03,
                          max_nepochs=200,
                          alpha=1e-7,
                          alpha_d=None,
                          alpha_b=None,
                          gamma=1e-6,
                          learning_rate=1.0,
                          output_folder=None,
                          minibatch_size=None,
                          save_intermediate=False,
                          full_intermediate=False,
                          energy_ev=5000,
                          psize_cm=1e-7,
                          n_epochs_mask_release=None,
                          cpu_only=False,
                          save_path='.',
                          phantom_path='phantom',
                          shrink_cycle=20,
                          core_parallelization=True,
                          free_prop_cm=None,
                          multiscale_level=1,
                          n_epoch_final_pass=None,
                          initial_guess=None,
                          n_batch_per_update=5,
                          dynamic_rate=True,
                          probe_type='plane',
                          probe_initial=None,
                          probe_learning_rate=1e-3,
                          pupil_function=None,
                          theta_downsample=None,
                          forward_algorithm='fresnel',
                          random_theta=True,
                          object_type='normal',
                          kernel_size=17,
                          debug=False,
                          **kwargs):
    """
    Reconstruct a beyond depth-of-focus object.
    :param fname: Filename and path of raw data file. Must be in HDF5 format.
    :param theta_st: Starting rotation angle.
    :param theta_end: Ending rotation angle.
    :param n_epochs: Number of epochs to be executed. If given 'auto', optimizer will stop
                     when reduction rate of loss function goes below crit_conv_rate.
    :param crit_conv_rate: Reduction rate of loss function below which the optimizer should
                           stop.
    :param max_nepochs: The maximum number of epochs to be executed if n_epochs is 'auto'.
    :param alpha: Weighting coefficient for both delta and beta regularizer. Should be None
                  if alpha_d and alpha_b are specified.
    :param alpha_d: Weighting coefficient for delta regularizer.
    :param alpha_b: Weighting coefficient for beta regularizer.
    :param gamma: Weighting coefficient for TV regularizer.
    :param learning_rate: Learning rate of ADAM.
    :param output_folder: Name of output folder. Put None for auto-generated pattern.
    :param minibatch_size: Size of minibatch.
    :param save_intermediate: Whether to save the object after each epoch.
    :param energy_ev: Beam energy in eV.
    :param psize_cm: Pixel size in cm.
    :param n_epochs_mask_release: The number of epochs after which the finite support mask
                                  is released. Put None to disable this feature.
    :param cpu_only: Whether to disable GPU.
    :param save_path: The location of finite support mask, the prefix of output_folder and
                      other metadata.
    :param phantom_path: The location of phantom objects (for test version only).
    :param shrink_cycle: Shrink-wrap is executed per every this number of epochs.
    :param core_parallelization: Whether to use Horovod for parallelized computation within
                                 this function.
    :param free_prop_cm: The distance to propagate the wavefront in free space after exiting
                         the sample, in cm.
    :param multiscale_level: The level of multiscale processing. When this number is m and
                             m > 1, m - 1 low-resolution reconstructions will be performed
                             before reconstructing with the original resolution. The downsampling
                             factor for these coarse reconstructions will be [2^(m - 1),
                             2^(m - 2), ..., 2^1].
    :param n_epoch_final_pass: specify a number of iterations for the final pass if multiscale
                               is activated. If None, it will be the same as n_epoch.
    :param initial_guess: supply an initial guess. If None, object will be initialized with noises.
    :param n_batch_per_update: number of minibatches during which gradients are accumulated, after
                               which obj is updated.
    :param dynamic_rate: when n_batch_per_update > 1, adjust learning rate dynamically to allow it
                         to decrease with epoch number
    :param probe_type: Type of wavefront. Can be 'plane', 'fixed', 'optimizable', 'point',
                       or 'gaussian'. If 'optimizable', the probe function will be optimized
                       along with the object.
    :param probe_initial: Can be provided for 'optimizable' probe_type, and must be provided
                          for 'fixed'.
    """
    def forward_pass(obj_delta, obj_beta, this_ind_batch):
        obj_stack = np.stack([obj_delta, obj_beta], axis=3)
        obj_rot_batch = []
        for i in range(minibatch_size):
            obj_rot_batch.append(
                apply_rotation(
                    obj_stack, coord_ls[this_ind_batch[i]],
                    'arrsize_{}_{}_{}_ntheta_{}'.format(
                        dim_y, dim_x, dim_x, n_theta)))
        obj_rot_batch = np.stack(obj_rot_batch)

        exiting_batch = multislice_propagate_cnn(obj_rot_batch[:, :, :, :, 0],
                                                 obj_rot_batch[:, :, :, :, 1],
                                                 probe_real,
                                                 probe_imag,
                                                 energy_ev,
                                                 [psize_cm * ds_level] * 3,
                                                 free_prop_cm=free_prop_cm,
                                                 kernel_size=kernel_size)
        return exiting_batch

    def calculate_loss(obj_delta, obj_beta, this_ind_batch, this_prj_batch):

        obj_stack = np.stack([obj_delta, obj_beta], axis=3)
        obj_rot_batch = []
        for i in range(minibatch_size):
            obj_rot_batch.append(
                apply_rotation(
                    obj_stack, coord_ls[this_ind_batch[i]],
                    'arrsize_{}_{}_{}_ntheta_{}'.format(
                        dim_y, dim_x, dim_x, n_theta)))
        obj_rot_batch = np.stack(obj_rot_batch)

        exiting_batch = multislice_propagate_cnn(obj_rot_batch[:, :, :, :, 0],
                                                 obj_rot_batch[:, :, :, :, 1],
                                                 probe_real,
                                                 probe_imag,
                                                 energy_ev,
                                                 [psize_cm * ds_level] * 3,
                                                 free_prop_cm=free_prop_cm,
                                                 kernel_size=kernel_size)
        loss = np.mean((np.abs(exiting_batch) - np.abs(this_prj_batch))**2)

        if alpha_d is None:
            reg_term = alpha * (np.sum(np.abs(obj_delta)) + np.sum(
                np.abs(obj_beta))) + gamma * total_variation_3d(obj_delta)
        else:
            if gamma == 0:
                reg_term = alpha_d * np.sum(
                    np.abs(obj_delta)) + alpha_b * np.sum(np.abs(obj_beta))
            else:
                reg_term = alpha_d * np.sum(
                    np.abs(obj_delta)) + alpha_b * np.sum(np.abs(
                        obj_beta)) + gamma * total_variation_3d(obj_delta)
        loss = loss + reg_term
        print('Current loss: {}'.format(loss))

        return loss

    t_zero = time.time()
    try:
        comm = MPI.COMM_WORLD
        size = comm.Get_size()
        rank = comm.Get_rank()
        mpi_ok = True
    except:
        from pseudo import Mpi
        comm = Mpi()
        size = 1
        rank = 0
        mpi_ok = False

    # read data
    t0 = time.time()
    print_flush('Reading data...', 0, rank)
    f = h5py.File(os.path.join(save_path, fname), 'r')
    prj_0 = f['exchange/data'][...].astype('complex64')
    theta = -np.linspace(theta_st, theta_end, prj_0.shape[0], dtype='float32')
    n_theta = len(theta)
    prj_theta_ind = np.arange(n_theta, dtype=int)
    if theta_downsample is not None:
        prj_0 = prj_0[::theta_downsample]
        theta = theta[::theta_downsample]
        prj_theta_ind = prj_theta_ind[::theta_downsample]
        n_theta = len(theta)
    original_shape = prj_0.shape
    comm.Barrier()
    print_flush('Data reading: {} s'.format(time.time() - t0), 0, rank)
    print_flush('Data shape: {}'.format(original_shape), 0, rank)
    comm.Barrier()

    initializer_flag = False

    if output_folder is None:
        output_folder = 'recon_360_minibatch_{}_' \
                        'mskrls_{}_' \
                        'shrink_{}_' \
                        'iter_{}_' \
                        'alphad_{}_' \
                        'alphab_{}_' \
                        'gamma_{}_' \
                        'rate_{}_' \
                        'energy_{}_' \
                        'size_{}_' \
                        'ntheta_{}_' \
                        'prop_{}_' \
                        'ms_{}_' \
                        'cpu_{}' \
            .format(minibatch_size, n_epochs_mask_release, shrink_cycle,
                    n_epochs, alpha_d, alpha_b,
                    gamma, learning_rate, energy_ev,
                    prj_0.shape[-1], prj_0.shape[0], free_prop_cm,
                    multiscale_level, cpu_only)
        if abs(PI - theta_end) < 1e-3:
            output_folder += '_180'

    if save_path != '.':
        output_folder = os.path.join(save_path, output_folder)

    for ds_level in range(multiscale_level - 1, -1, -1):

        ds_level = 2**ds_level
        print_flush('Multiscale downsampling level: {}'.format(ds_level), 0,
                    rank)
        comm.Barrier()

        # downsample data
        prj = np.copy(prj_0)
        if ds_level > 1:
            prj = prj[:, ::ds_level, ::ds_level]
            prj = prj.astype('complex64')
        comm.Barrier()

        ind_ls = np.arange(n_theta)
        np.random.shuffle(ind_ls)
        if minibatch_size is None:
            minibatch_size = n_theta
        n_tot_per_batch = size * minibatch_size
        if n_theta % n_tot_per_batch > 0:
            ind_ls = np.concatenate(
                [ind_ls, ind_ls[:n_tot_per_batch - n_theta % n_tot_per_batch]])
        ind_ls = split_tasks(ind_ls, n_tot_per_batch)
        ind_ls = [np.sort(x) for x in ind_ls]
        print(len(ind_ls), n_theta % n_tot_per_batch)

        dim_y, dim_x = prj.shape[-2:]
        comm.Barrier()

        # read rotation data
        try:
            coord_ls = read_all_origin_coords(
                'arrsize_{}_{}_{}_ntheta_{}'.format(dim_y, dim_x, dim_x,
                                                    n_theta), n_theta)
        except:
            save_rotation_lookup([dim_y, dim_x, dim_x], n_theta)
            coord_ls = read_all_origin_coords(
                'arrsize_{}_{}_{}_ntheta_{}'.format(dim_y, dim_x, dim_x,
                                                    n_theta), n_theta)


        if n_epochs_mask_release is None:
            n_epochs_mask_release = np.inf

        try:
            mask = dxchange.read_tiff_stack(
                os.path.join(save_path, 'fin_sup_mask', 'mask_00000.tiff'),
                range(prj_0.shape[1]), 5)
        except:
            try:
                mask = dxchange.read_tiff(
                    os.path.join(save_path, 'fin_sup_mask', 'mask.tiff'))
            except:
                obj_pr = dxchange.read_tiff_stack(
                    os.path.join(save_path, 'paganin_obj/recon_00000.tiff'),
                    range(prj_0.shape[1]), 5)
                obj_pr = gaussian_filter(np.abs(obj_pr),
                                         sigma=3,
                                         mode='constant')
                mask = np.zeros_like(obj_pr)
                mask[obj_pr > 1e-5] = 1
                dxchange.write_tiff_stack(mask,
                                          os.path.join(save_path,
                                                       'fin_sup_mask/mask'),
                                          dtype='float32',
                                          overwrite=True)
        if ds_level > 1:
            mask = mask[::ds_level, ::ds_level, ::ds_level]
        dim_z = mask.shape[-1]

        # unify random seed for all threads
        comm.Barrier()
        seed = int(time.time() / 60)
        np.random.seed(seed)
        comm.Barrier()

        if not initializer_flag:
            if initial_guess is None:
                print_flush('Initializing with Gaussian random.', 0, rank)
                obj_delta = np.random.normal(
                    size=[dim_y, dim_x, dim_z], loc=8.7e-7, scale=1e-7) * mask
                obj_beta = np.random.normal(
                    size=[dim_y, dim_x, dim_z], loc=5.1e-8, scale=1e-8) * mask
                obj_delta[obj_delta < 0] = 0
                obj_beta[obj_beta < 0] = 0
            else:
                print_flush('Using supplied initial guess.', 0, rank)
                sys.stdout.flush()
                obj_delta = initial_guess[0]
                obj_beta = initial_guess[1]
        else:
            print_flush('Initializing with upsampled result of previous level.', 0, rank)
            obj_delta = dxchange.read_tiff(
                os.path.join(output_folder,
                             'delta_ds_{}.tiff'.format(ds_level * 2)))
            obj_beta = dxchange.read_tiff(
                os.path.join(output_folder,
                             'beta_ds_{}.tiff'.format(ds_level * 2)))
            obj_delta = upsample_2x(obj_delta)
            obj_beta = upsample_2x(obj_beta)
            obj_delta += np.random.normal(
                size=[dim_y, dim_x, dim_z], loc=8.7e-7, scale=1e-7) * mask
            obj_beta += np.random.normal(
                size=[dim_y, dim_x, dim_z], loc=5.1e-8, scale=1e-8) * mask
            obj_delta[obj_delta < 0] = 0
            obj_beta[obj_beta < 0] = 0
        obj_size = obj_delta.shape
        if object_type == 'phase_only':
            obj_beta[...] = 0
        elif object_type == 'absorption_only':
            obj_delta[...] = 0
        # ====================================================

        if probe_type == 'plane':
            probe_real = np.ones([dim_y, dim_x])
            probe_imag = np.zeros([dim_y, dim_x])
        elif probe_type == 'optimizable':
            if probe_initial is not None:
                probe_mag, probe_phase = probe_initial
                probe_real, probe_imag = mag_phase_to_real_imag(
                    probe_mag, probe_phase)
            else:
                # probe_mag = np.ones([dim_y, dim_x])
                # probe_phase = np.zeros([dim_y, dim_x])
                back_prop_cm = (free_prop_cm + (psize_cm * obj_size[2])
                                ) if free_prop_cm is not None else (
                                    psize_cm * obj_size[2])
                probe_init = create_probe_initial_guess(
                    os.path.join(save_path, fname), back_prop_cm * 1.e7,
                    energy_ev, psize_cm * 1.e7)
                probe_real = probe_init.real
                probe_imag = probe_init.imag
            if pupil_function is not None:
                probe_real = probe_real * pupil_function
                probe_imag = probe_imag * pupil_function
        elif probe_type == 'fixed':
            probe_mag, probe_phase = probe_initial
            probe_real, probe_imag = mag_phase_to_real_imag(
                probe_mag, probe_phase)
        elif probe_type == 'point':
            # this should be in spherical coordinates
            probe_real = np.ones([dim_y, dim_x])
            probe_imag = np.zeros([dim_y, dim_x])
        elif probe_type == 'gaussian':
            probe_mag_sigma = kwargs['probe_mag_sigma']
            probe_phase_sigma = kwargs['probe_phase_sigma']
            probe_phase_max = kwargs['probe_phase_max']
            py = np.arange(obj_size[0]) - (obj_size[0] - 1.) / 2
            px = np.arange(obj_size[1]) - (obj_size[1] - 1.) / 2
            pxx, pyy = np.meshgrid(px, py)
            probe_mag = np.exp(-(pxx**2 + pyy**2) / (2 * probe_mag_sigma**2))
            probe_phase = probe_phase_max * np.exp(-(pxx**2 + pyy**2) /
                                                   (2 * probe_phase_sigma**2))
            probe_real, probe_imag = mag_phase_to_real_imag(
                probe_mag, probe_phase)
        else:
            raise ValueError(
                'Invalid wavefront type. Choose from \'plane\', \'fixed\', '
                '\'optimizable\', \'point\', \'gaussian\'.')

        # =============finite support===================
        obj_delta = obj_delta * mask
        obj_beta = obj_beta * mask
        obj_delta = np.clip(obj_delta, 0, None)
        obj_beta = np.clip(obj_beta, 0, None)
        # ==============================================

        # generate Fresnel kernel
        voxel_nm = np.array([psize_cm] * 3) * 1.e7 * ds_level
        lmbda_nm = 1240. / energy_ev
        delta_nm = voxel_nm[-1]
        h = get_kernel(delta_nm, lmbda_nm, voxel_nm, [dim_y, dim_y, dim_x])

        loss_grad = grad(calculate_loss, [0, 1])

        print_flush('Optimizer started.', 0, rank)
        if rank == 0:
            create_summary(output_folder, locals(), preset='fullfield')

        cont = True
        i_epoch = 0
        while cont:
            m, v = (None, None)
            t0 = time.time()
            for i_batch in range(len(ind_ls)):

                t00 = time.time()
                this_ind_batch = ind_ls[i_batch][rank *
                                                 minibatch_size:(rank + 1) *
                                                 minibatch_size]
                this_prj_batch = prj[this_ind_batch]
                grads = loss_grad(obj_delta, obj_beta, this_ind_batch,
                                  this_prj_batch)
                grads = np.array(grads)
                this_grads = np.copy(grads)
                if mpi_ok:
                    grads = np.zeros_like(this_grads)
                    comm.Allreduce(this_grads, grads)
                    grads = grads / size
                (obj_delta, obj_beta), m, v = apply_gradient_adam(
                    np.array([obj_delta, obj_beta]),
                    grads,
                    i_batch,
                    m,
                    v,
                    step_size=learning_rate)

                dxchange.write_tiff(obj_delta,
                                    fname=os.path.join(
                                        output_folder, 'intermediate',
                                        'current'),
                                    dtype='float32',
                                    overwrite=True)
                # finite support
                obj_delta = obj_delta * mask
                obj_beta = obj_beta * mask
                obj_delta = np.clip(obj_delta, 0, None)
                obj_beta = np.clip(obj_beta, 0, None)

                # shrink wrap
                if shrink_cycle is not None:
                    if i_epoch >= shrink_cycle:
                        boolean = obj_delta > 1e-15
                        mask = mask * boolean

                print_flush('Minibatch done in {} s (rank {})'.format(
                    time.time() - t00, rank))

                if i_batch % 10 == 0 and debug:
                    temp_exit = forward_pass(obj_delta, obj_beta,
                                             this_ind_batch)
                    dxchange.write_tiff(abs(temp_exit),
                                        os.path.join(
                                            output_folder, 'exits',
                                            '{}-{}'.format(i_epoch, i_batch)),
                                        dtype='float32',
                                        overwrite=True)

            if n_epochs == 'auto':
                pass
            else:
                if i_epoch == n_epochs - 1: cont = False
            i_epoch = i_epoch + 1

            print_flush(
                'Epoch {} (rank {}); loss = {}; Delta-t = {} s; current time = {}.'
                .format(
                    i_epoch, rank,
                    calculate_loss(obj_delta, obj_beta, this_ind_batch,
                                   this_prj_batch),
                    time.time() - t0,
                    time.time() - t_zero))
        dxchange.write_tiff(obj_delta,
                            fname=os.path.join(output_folder,
                                               'delta_ds_{}'.format(ds_level)),
                            dtype='float32',
                            overwrite=True)
        dxchange.write_tiff(obj_beta,
                            fname=os.path.join(output_folder,
                                               'beta_ds_{}'.format(ds_level)),
                            dtype='float32',
                            overwrite=True)

        print_flush('Current iteration finished.', 0, rank)
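# A hypothetical invocation, only to illustrate the calling convention from the
# docstring above; 'data.h5' and every parameter value below are made up:
reconstruct_fullfield(fname='data.h5',
                      theta_st=0,
                      theta_end=PI,
                      n_epochs=10,
                      minibatch_size=10,
                      energy_ev=5000,
                      psize_cm=1e-7,
                      probe_type='plane',
                      cpu_only=True)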
Ejemplo n.º 46
0
def test_lds_log_probability_perf(T=1000, D=10, N_iter=10):
    """
    Compare performance of banded method vs message passing in pylds.
    """
    print("Comparing methods for T={} D={}".format(T, D))

    from pylds.lds_messages_interface import kalman_info_filter, kalman_filter

    # Convert LDS parameters into info form for pylds
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    Qis = np.matmul(Qi_sqrts, np.swapaxes(Qi_sqrts, -1, -2))
    Ris = np.matmul(Ri_sqrts, np.swapaxes(Ri_sqrts, -1, -2))
    x = npr.randn(T, D)

    print("Timing banded method")
    start = time.time()
    for itr in range(N_iter):
        lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts)
    stop = time.time()
    print("Time per iter: {:.4f}".format((stop - start) / N_iter))

    # Compare to Kalman Filter
    mu_init = np.zeros(D)
    sigma_init = np.eye(D)
    Bs = np.ones((D, 1))
    sigma_states = np.linalg.inv(Qis)
    Cs = np.eye(D)
    Ds = np.zeros((D, 1))
    sigma_obs = np.linalg.inv(Ris)
    inputs = bs
    data = ms

    print("Timing PyLDS message passing (kalman_filter)")
    start = time.time()
    for itr in range(N_iter):
        kalman_filter(mu_init, sigma_init,
                      np.concatenate([As, np.eye(D)[None, :, :]]), Bs,
                      np.concatenate([sigma_states,
                                      np.eye(D)[None, :, :]]), Cs, Ds,
                      sigma_obs, inputs, data)
    stop = time.time()
    print("Time per iter: {:.4f}".format((stop - start) / N_iter))

    # Info form comparison
    J_init = np.zeros((D, D))
    h_init = np.zeros(D)
    log_Z_init = 0

    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)
    J_pair_21 = J_lower_diag
    J_pair_22 = J_diag[1:]
    J_pair_11 = J_diag[:-1]
    J_pair_11[1:] = 0
    h_pair_2 = h[1:]
    h_pair_1 = h[:-1]
    h_pair_1[1:] = 0
    log_Z_pair = 0

    J_node = np.zeros((T, D, D))
    h_node = np.zeros((T, D))
    log_Z_node = 0

    print("Timing PyLDS message passing (kalman_info_filter)")
    start = time.time()
    for itr in range(N_iter):
        kalman_info_filter(J_init, h_init, log_Z_init, J_pair_11, J_pair_21,
                           J_pair_22, h_pair_1, h_pair_2, log_Z_pair, J_node,
                           h_node, log_Z_node)
    stop = time.time()
    print("Time per iter: {:.4f}".format((stop - start) / N_iter))
Ejemplo n.º 47
0
# Author: Pierre Ablin <*****@*****.**>
# License: MIT

# Example with several variables

import autograd.numpy as np

from autoptim import minimize

n = 1000
n_components = 3

x = np.concatenate(
    (np.random.randn(n) - 1, 3 * np.random.randn(n), np.random.randn(n) + 2))

# Here, the model should fit both the means and the variances. Using
# scipy.optimize.minimize, one would have to vectorize by hand these variables.


def loss(means, variances, x):
    tmp = np.zeros(n_components * n)
    for m, v in zip(means, variances):
        tmp += np.exp(-(x - m)**2 / (2 * v**2)) / v
    return -np.sum(np.log(tmp))


# autoptim can handle lists of unknown variables

means0 = np.random.randn(n_components)
variances0 = np.random.rand(n_components)
optim_vars = [means0, variances0]
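# The snippet stops just short of the optimization itself; completing it with
# autoptim's minimize, which accepts the list of variables directly (the
# (variables, result) return convention is assumed from autoptim's docs):
(means, variances), res = minimize(loss, optim_vars, args=(x,))
print(means, variances)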
Ejemplo n.º 48
0
import autograd.numpy as np
from autograd import grad


def ProductSines10D(x, noisy=True, fulldim=False,
                    relevant_dims=np.arange(10), scale=1.0):
    # signature reconstructed from the calls below; the defaults for noisy,
    # relevant_dims and scale are assumptions
    if x.ndim == 1:
        x = x[None]
    if fulldim:
        x_resc = x * (2. * np.pi)
    else:
        x_resc = np.copy(x[:, relevant_dims]) * (2. * np.pi)  # [0, 1] -> [0, 2pi]
    f_x = np.sin(x_resc[:, 0])[:, None] * np.prod(np.sin(x_resc), axis=1)[:, None] * 10  # rescale by factor
    if noisy:
        noise = np.random.normal(loc=0., scale=scale, size=f_x.shape[0])[:, None]
        y = f_x + noise
        return y
    return f_x
func = lambda x: ProductSines10D(x, noisy=False, fulldim=True)
func_grad = grad(func)
minimizer = np.ones(shape=[1, int(100)], dtype=np.float64) * (np.pi * (3. / 2.)) / (2*np.pi)
g0 = func_grad(minimizer)

np.random.seed(123)
num = int(5e06)
shape_x = [num, int(10)]
x = np.random.uniform(low=0., high=1., size=np.prod(shape_x)).reshape(shape_x)
grads = []
for x_i, i in zip(list(x), list(range(num))):
    print(i)
    grad_i = func_grad(x_i[None])
    grads.append(grad_i)



grads_all = np.concatenate(grads, axis=0)
# f_sample = objective.f(x, noisy=False, fulldim=True)
# f_min_sample = f_sample.min()
Ejemplo n.º 49
0
    def __init__(self, frame, sky_coord, observations, entropic=False):
        """Source intialized with a single pixel

        Parameters
        ----------
        frame: `~scarlet.Frame`
            The frame of the model
        sky_coord: tuple
            Center of the source
        observations: instance or list of `~scarlet.Observation`
            Observation(s) to initialize this source
        entropic: `bool`
            Whether or not to enforce more information on SED
        """
        C, Ny, Nx = frame.shape
        self.entropic = entropic
        self.center = np.array(frame.get_pixel(sky_coord), dtype="float")

        # initialize SED from sky_coord
        try:
            iter(observations)
        except TypeError:
            observations = [observations]

        # determine initial SED from peak position
        # SED in the frame for source detection
        seds = []
        for obs in observations:
            _sed = get_psf_sed(sky_coord, obs, frame)
            seds.append(_sed)
        sed = np.concatenate(seds).reshape(-1)

        if np.any(sed <= 0):
            # If the flux in all channels is <= 0,
            # the new sed will be filled with NaN values,
            # which will cause the code to crash later
            msg = "Zero or negative SED {} at y={}, x={}".format(
                sed, *sky_coord)
            if np.all(sed <= 0):
                logger.warning(msg)
            else:
                logger.info(msg)

        # set up parameters
        sed_constraints = PositivityConstraint()
        if entropic:
            sed_constraints = ConstraintChain(sed_constraints,
                                              EntropyConstraint())
        sed = Parameter(
            sed,
            name="sed",
            step=partial(relative_step, factor=1e-2),
            constraint=sed_constraints,
        )

        center = Parameter(self.center, name="center", step=1e-1)

        # define bbox
        pixel_center = tuple(np.round(center).astype("int"))
        front, back = 0, C
        bottom = pixel_center[0] - frame.psf.shape[1] // 2
        top = pixel_center[0] + frame.psf.shape[1] // 2
        left = pixel_center[1] - frame.psf.shape[2] // 2
        right = pixel_center[1] + frame.psf.shape[2] // 2
        bbox = Box.from_bounds((front, back), (bottom, top), (left, right))

        super().__init__(frame, sed, center, self._psf_wrapper, bbox=bbox)
Example No. 50
        def animate(k):
            ax1.cla()
            ax2.cla()
            
            # print rendering update
            if np.mod(k + 1, 25) == 0:
                print('rendering animation frame ' + str(k + 1) + ' of ' + str(num_frames))
            if k == num_frames - 1:
                print('animation rendering complete!')
                time.sleep(1.5)
                clear_output()
                
            # plot initial point and evaluation
            if k == 0:
                w_val = self.w_init
                g_val = self.g(w_val)
                ax1.scatter(w_val,g_val,s = 100,c = 'm',edgecolor = 'k',linewidth = 0.7,zorder = 2)            # plot point of tangency
                # ax1.scatter(w_val,0,s = 100,c = 'm',edgecolor = 'k',linewidth = 0.7, zorder = 2, marker = 'X')
                # plot function 
                ax1.plot(w_plot,g_plot,color = 'k',zorder = 0)               # plot function

            # plot function alone first along with initial point
            if k > 0:
                alpha = self.steplength_range[k-1]
                
                # run gradient descent method
                self.w_hist = []
                self.run_gradient_descent(alpha = alpha)
                
                # plot function
                self.plot_function(ax1)
        
                # colors for points
                s = np.linspace(0,1,len(self.w_hist[:round(len(self.w_hist)/2)]))
                s.shape = (len(s),1)
                t = np.ones(len(self.w_hist[round(len(self.w_hist)/2):]))
                t.shape = (len(t),1)
                s = np.vstack((s,t))
                self.colorspec = []
                self.colorspec = np.concatenate((s,np.flipud(s)),1)
                self.colorspec = np.concatenate((self.colorspec,np.zeros((len(s),1))),1)
        
                # plot everything for each iteration 
                for j in range(len(self.w_hist)):  
                    w_val = self.w_hist[j]
                    g_val = self.g(w_val)
                    grad_val = self.grad(w_val)
                    ax1.scatter(w_val,g_val,s = 90,c = self.colorspec[j],edgecolor = 'k',linewidth = 0.7,zorder = 3)            # plot point of tangency
                    
                    # ax1.scatter(w_val,0,s = 90,facecolor = self.colorspec[j],marker = 'X',edgecolor = 'k',linewidth = 0.7, zorder = 2)
                    
                    # determine width to plot the approximation -- so its length == width defined above
                    div = float(1 + grad_val**2)
                    w1 = w_val - math.sqrt(width/div)
                    w2 = w_val + math.sqrt(width/div)

                    # use point-slope form of line to plot
                    wrange = np.linspace(w1,w2, 100)
                    h = g_val + grad_val*(wrange - w_val)
                
                    # plot tracers connecting consecutive points on the cost (for visualization purposes)
                    if tracers == 'on':
                        if j > 0:
                            w_old = self.w_hist[j-1]
                            w_new = self.w_hist[j]
                            g_old = self.g(w_old)
                            g_new = self.g(w_new)
                            ax1.quiver(w_old, g_old, w_new - w_old, g_new - g_old, scale_units='xy', angles='xy', scale=1, color = self.colorspec[j],linewidth = 1.5,alpha = 0.2,linestyle = '-',headwidth = 4.5,edgecolor = 'k',headlength = 10,headaxislength = 7)
            
                    ### plot all on cost function decrease plot
                    ax2.scatter(j,g_val,s = 90,c = self.colorspec[j],edgecolor = 'k',linewidth = 0.7,zorder = 3)            # plot point of tangency
                    
                    # clean up second axis, set title on first
                    ax2.set_xticks(np.arange(len(self.w_hist)))
                    ax1.set_title(r'$\alpha = $' + r'{:.2f}'.format(alpha),fontsize = 14)

                    # plot connector between points for visualization purposes
                    if j > 0:
                        w_old = self.w_hist[j-1]
                        w_new = self.w_hist[j]
                        g_old = self.g(w_old)
                        g_new = self.g(w_new)
                        ax2.plot([j-1,j],[g_old,g_new],color = self.colorspec[j],linewidth = 2,alpha = 0.4,zorder = 1)      # plot approx
 
            ### clean up function plot ###
            # fix viewing limits on function plot
            #ax1.set_xlim([-3,3])
            #ax1.set_ylim([min(g_plot) - ggap,max(g_plot) + ggap])
            
            # draw axes and labels
            ax1.set_xlabel(r'$w$',fontsize = 13)
            ax1.set_ylabel(r'$g(w)$',fontsize = 13,rotation = 0,labelpad = 25)   

            ax2.set_xlabel('iteration',fontsize = 13)
            ax2.set_ylabel(r'$g(w)$',fontsize = 13,rotation = 0,labelpad = 25)
            ax1.axhline(y=0, color='k',zorder = 0,linewidth = 0.5)
            ax2.axhline(y=0, color='k',zorder = 0,linewidth = 0.5)

            return artist,
Example No. 51
    #    pickle.dump(lam_list, f)

#############################################
# Nonparametric variational inference code  #
#  --- save posterior parameters            #
#############################################

if args.npvi:

    init_with_mfvi = True
    if init_with_mfvi:
        mfvi_lam = mfvi_init()

        # initialize theta
        theta_mfvi = np.atleast_2d(
            np.concatenate([mfvi_lam[:D], [2 * mfvi_lam[D:].mean()]]))
        mu0 = vi.bbvi_npvi.mogsamples(args.ncomp, theta_mfvi)

        # stack the sampled means with a shared log-scale column
        theta0 = np.column_stack(
            [mu0, np.ones(args.ncomp) * theta_mfvi[0, -1]])

    else:
        theta0 = np.column_stack(
            [10 * np.random.randn(args.ncomp, D), -2 * np.ones(args.ncomp)])

    # create the npvi object and optimize from the initial theta
    npvi = vi.NPVI(lnpdf, D=D)
    mu, s2, elbo_vals, theta = npvi.run(theta0.copy(), verbose=False)
    print(elbo_vals)
Example No. 52
    def _merge_array(self, *args):
        # concatenate the named arrays along the last axis and store the
        # result under the concatenated key
        self.sleep_data["".join(args)] = np.concatenate(
            [self.sleep_data[k] for k in args], -1)
Example No. 53
def init_gaussian_var_params(D, mean_mean=-1, log_std_mean=-5, scale=0.1, rs=npr.RandomState(0)):
    init_mean    = mean_mean * np.ones(D) + rs.randn(D) * scale
    init_log_std = log_std_mean * np.ones(D) + rs.randn(D) * scale
    return np.concatenate([init_mean, init_log_std])
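
# For reference, the packed vector splits back into its two halves (a usage
# sketch, not part of the original snippet; assumes this file's np/npr imports):
params = init_gaussian_var_params(D=5)
mean, log_std = params[:5], params[5:]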
Example No. 54
    def wake(self, wake_data, it):
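        # Hedged reading of the code below: gradient estimates are computed
        # for the last two layers first, then for layers nlayer-3 down to 0;
        # each is prepended to gs so the final list is ordered from layer 0
        # upward, and a single parameter step is taken at the end.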

        ddc = self.reg
        self.wake_data = wake_data.copy()

        gs = []

        nl_obs = self.nlayer - 1

        mean_name_higher = "mx%d_x%d" % (self.nlayer - 2, nl_obs)
        fun_name = "x%d->x%d" % (nl_obs, self.nlayer - 2)
        self.wake_data[mean_name_higher] = ddc.predict(self.wake_data,
                                                       fun_name)

        if self.layer_plastic[-1]:

            grad_name = "x%d->dnat%d" % (self.nlayer - 2, nl_obs)
            dnat = self.approx_E(mean_name_higher, grad_name)
            #dnat  = self.model.dists[-1].dnat(self.wake_data["x%d"%(self.nlayer-2)])

            grad_name = "x%d->dnorm%d" % (self.nlayer - 2, nl_obs)
            dnorm = self.approx_E(mean_name_higher, grad_name)
            #dnorm  = self.model.dists[-1].dnorm(self.wake_data["x%d"%(self.nlayer-2)])

            suff = self.model.dists[-1].suff(self.wake_data["x%d" % (nl_obs)])
            g = (self.model.dists[-1].dnatsuff_from_dnatsuff(dnat, suff) -
                 dnorm)
            self.wake_data["dlogp%d" % (nl_obs)] = g
            gs.insert(0, g.mean(0))
        else:
            gs.insert(0, np.zeros_like(self.model.dists[-1].ps))

        if self.layer_plastic[-2]:

            fun_name = "x%d->x%d" % (nl_obs, self.nlayer - 2)
            grad_name = "x%d->dlogp%d" % (self.nlayer - 2, self.nlayer - 2)

            g = self.approx_E(mean_name_higher, grad_name)
            self.wake_data["dlogp%d" % (self.nlayer - 2)] = g
            gs.insert(0, g.mean(0))
        else:
            gs.insert(0, np.zeros_like(self.model.dists[-2].ps))

        for i in range(self.nlayer - 3, -1, -1):
            mean_name = "mx%d_x%d" % (i, nl_obs)
            fun_name = "mx%d_x%d->x%d" % (i + 1, nl_obs, i)
            grad_name = "x%d->dlogp%d" % (i, i)

            self.wake_data[mean_name] = ddc.predict(self.wake_data, fun_name)

            if self.layer_plastic[i]:
                g = self.approx_E(mean_name, grad_name)
                self.wake_data["dlogp%d" % i] = g
                gs.insert(0, g.mean(0))

            else:
                gs.insert(0, np.zeros_like(self.model.dists[i].ps))

        gs = np.concatenate(gs)

        self.gradient_step(gs, it)
Example No. 55
    def draw_surface(self, g, ax, **kwargs):
        xmin = -3.1
        xmax = 3.1
        ymin = -3.1
        ymax = 3.1
        if 'xmin' in kwargs:
            xmin = kwargs['xmin']
        if 'xmax' in kwargs:
            xmax = kwargs['xmax']
        if 'ymin' in kwargs:
            ymin = kwargs['ymin']
        if 'ymax' in kwargs:
            ymax = kwargs['ymax']

        #### define input space for function and evaluate ####
        w1 = np.linspace(xmin, xmax, 200)
        w2 = np.linspace(ymin, ymax, 200)
        w1_vals, w2_vals = np.meshgrid(w1, w2)
        w1_vals.shape = (len(w1)**2, 1)
        w2_vals.shape = (len(w2)**2, 1)
        h = np.concatenate((w1_vals, w2_vals), axis=1)
        func_vals = np.asarray([g(np.reshape(s, (2, 1))) for s in h])

        ### plot function as surface ###
        w1_vals.shape = (len(w1), len(w2))
        w2_vals.shape = (len(w1), len(w2))
        func_vals.shape = (len(w1), len(w2))
        ax.plot_surface(w1_vals,
                        w2_vals,
                        func_vals,
                        alpha=0.1,
                        color='w',
                        rstride=25,
                        cstride=25,
                        linewidth=1,
                        edgecolor='k',
                        zorder=2)

        # plot z=0 plane
        ax.plot_surface(w1_vals,
                        w2_vals,
                        func_vals * 0,
                        alpha=0.1,
                        color='w',
                        zorder=1,
                        rstride=25,
                        cstride=25,
                        linewidth=0.3,
                        edgecolor='k')

        # clean up axis
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False

        ax.xaxis.pane.set_edgecolor('white')
        ax.yaxis.pane.set_edgecolor('white')
        ax.zaxis.pane.set_edgecolor('white')

        ax.xaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
        ax.yaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
        ax.zaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)

        ax.set_xlabel('$w_0$', fontsize=14)
        ax.set_ylabel('$w_1$', fontsize=14, rotation=0)
        ax.set_title('$g(w_0,w_1)$', fontsize=14)
Example No. 56
# imports reconstructed so this snippet runs stand-alone
import os

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import multivariate_normal

if not os.path.exists('plots'):
    os.mkdir('plots')

# Set priors, dataset, problem
D = 2  # dimensions
K = 3  # number of components, unnecessary ones should go to zero
N = 100  # number of points in synthetic dataset
N_its = 0  # number of updates

# Dataset
centres = [np.array([0., 8.]), np.array([5., 0.])]
covs = [np.eye(2), np.array([[0.6, 0.4], [0.4, 0.6]])]
X1 = multivariate_normal(mean=centres[0], cov=covs[0], size=int(N / 2))
X2 = multivariate_normal(mean=centres[1], cov=covs[1], size=int(N / 2))
X = np.concatenate((X1, X2))

# Variational priors
alpha0 = 1e-3  # as alpha0 -> 0, pi_k -> 0; as alpha0 -> Inf, pi_k -> 1/K
beta0 = 1e-10  # precision scaling of the Gaussian prior on the means; small => broad prior
m0 = np.zeros(2)  # zero by convention (symmetry)
W0 = np.eye(2)  # Wishart scale matrix
nu0 = 2  # Wishart degrees of freedom (> D - 1)

# r must be initialized randomly rather than uniformly: with identical
# responsibilities, every component receives exactly the same sufficient
# statistics in the update, so all components stay coincident and the
# symmetry between them is never broken
r = np.array([np.random.dirichlet(np.ones(K)) for _ in range(N)])
r = [r[:, k] for k in range(K)]
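
# Why uniform responsibilities stall the updates (illustrative sketch, not in
# the original): with identical r[:, k] for every k, each component's weighted
# mean reduces to the global data mean, so the K components never separate.
r_uniform = np.full((N, K), 1.0 / K)
Nk = r_uniform.sum(axis=0)                # each component claims N / K points
xbar = (r_uniform.T @ X) / Nk[:, None]    # every row equals X.mean(axis=0)
assert np.allclose(xbar, X.mean(axis=0))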

# neaten up the output
from tqdm import tqdm
Example No. 57
    def single_input_plot(self, g, weight_histories, cost_histories, **kwargs):
        # adjust viewing range
        wmin = -3.1
        wmax = 3.1
        if 'wmin' in kwargs:
            wmin = kwargs['wmin']
        if 'wmax' in kwargs:
            wmax = kwargs['wmax']

        onerun_perplot = False
        if 'onerun_perplot' in kwargs:
            onerun_perplot = kwargs['onerun_perplot']

        ### initialize figure
        fig = plt.figure(figsize=(9, 4))
        artist = fig

        # remove whitespace from figure
        #fig.subplots_adjust(left=0, right=1, bottom=0, top=1) # remove whitespace
        #fig.subplots_adjust(wspace=0.01,hspace=0.01)

        # create subplot with 2 panels, plot input function in center plot
        gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1])

        ### plot function in both panels
        w_plot = np.linspace(wmin, wmax, 500)
        g_plot = g(w_plot)
        gmin = np.min(g_plot)
        gmax = np.max(g_plot)
        g_range = gmax - gmin
        ggap = g_range * 0.1
        gmin -= ggap
        gmax += ggap

        # plot function, axes lines
        ax1.plot(w_plot, g_plot, color='k', zorder=2)  # plot function
        ax1.axhline(y=0, color='k', zorder=1, linewidth=0.25)
        ax1.axvline(x=0, color='k', zorder=1, linewidth=0.25)
        ax1.set_xlabel(r'$w$', fontsize=13)
        ax1.set_ylabel(r'$g(w)$', fontsize=13, rotation=0, labelpad=25)
        ax1.set_xlim(wmin, wmax)
        ax1.set_ylim(gmin, gmax)

        ax2.plot(w_plot, g_plot, color='k', zorder=2)  # plot function
        ax2.axhline(y=0, color='k', zorder=1, linewidth=0.25)
        ax2.axvline(x=0, color='k', zorder=1, linewidth=0.25)
        ax2.set_xlabel(r'$w$', fontsize=13)
        ax2.set_ylabel(r'$g(w)$', fontsize=13, rotation=0, labelpad=25)
        ax2.set_xlim(wmin, wmax)
        ax2.set_ylim(gmin, gmax)

        #### loop over histories and plot each
        for j in range(len(weight_histories)):
            w_hist = weight_histories[j]
            c_hist = cost_histories[j]

            # colors for points --> green as the algorithm begins, yellow as it converges, red at final point
            s = np.linspace(0, 1, len(w_hist[:round(len(w_hist) / 2)]))
            s.shape = (len(s), 1)
            t = np.ones(len(w_hist[round(len(w_hist) / 2):]))
            t.shape = (len(t), 1)
            s = np.vstack((s, t))
            self.colorspec = []
            self.colorspec = np.concatenate((s, np.flipud(s)), 1)
            self.colorspec = np.concatenate(
                (self.colorspec, np.zeros((len(s), 1))), 1)

            ### plot all history points
            ax = ax2
            if onerun_perplot:
                if j == 0:
                    ax = ax1
                if j == 1:
                    ax = ax2
            for k in range(len(w_hist)):
                # pick out current weight and function value from history, then plot
                w_val = w_hist[k]
                g_val = c_hist[k]
                ax.scatter(w_val,
                           g_val,
                           s=90,
                           color=self.colorspec[k],
                           edgecolor='k',
                           linewidth=0.5 * ((1 / (float(k) + 1)))**(0.4),
                           zorder=3,
                           marker='X')  # evaluation on function
                ax.scatter(w_val,
                           0,
                           s=90,
                           facecolor=self.colorspec[k],
                           edgecolor='k',
                           linewidth=0.5 * ((1 / (float(k) + 1)))**(0.4),
                           zorder=3)
Example No. 58
def TrackSmoothness(x):
    # sum of squared differences between consecutive curvatures, with a
    # wraparound term (k[-1] - k[0]) since the track is closed
    k = TrackCurvature(x)
    kdiff = np.concatenate([k[1:] - k[:-1], k[-1:] - k[:1]])
    return np.sum(kdiff**2)
Example No. 59
def pack(x, dc, dv):
    return np.concatenate([x.ravel(), dc.ravel(), dv.ravel()])
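
# The inverse operation is implied but not shown; a hedged unpack sketch,
# assuming the three array shapes are known at the call site:
def unpack(z, x_shape, dc_shape, dv_shape):
    i = int(np.prod(x_shape))
    j = i + int(np.prod(dc_shape))
    return (z[:i].reshape(x_shape),
            z[i:j].reshape(dc_shape),
            z[j:].reshape(dv_shape))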
Example No. 60
def compute_fprime_s_(Eta, Xi, s02):
    # s2 is Xi**2 plus s02 stacked twice along axis 0
    s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
    return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)