Example #1
    def plot(self, fits_image, ax, data_ax=None, diff_ax=None, unit_flux=False):
        import matplotlib.pyplot as plt
        import seaborn as sns
        from CelestePy.util.misc import plot_util
        if unit_flux:
            patch, ylim, xlim = self.compute_scatter_on_pixels(fits_image)
        else:
            patch, ylim, xlim = self.compute_model_patch(fits_image)
        cim = ax.imshow(patch, extent=(xlim[0], xlim[1], ylim[0], ylim[1]))
        plot_util.add_colorbar_to_axis(ax, cim)
        ax.set_title("model")

        if data_ax is not None:
            dpatch = fits_image.nelec[ylim[0]:ylim[1], xlim[0]:xlim[1]].copy()
            print "Data patch median: ", np.median(dpatch)
            dpatch -= np.median(dpatch)
            dpatch[dpatch<0] = 0.
            dim = data_ax.imshow(dpatch, extent=(xlim[0], xlim[1], ylim[0], ylim[1]))
            plot_util.add_colorbar_to_axis(data_ax, dim)
            data_ax.set_title("data")

        if diff_ax is not None:
            dpatch = fits_image.nelec[ylim[0]:ylim[1], xlim[0]:xlim[1]].copy()
            dpatch -= np.median(dpatch)
            dpatch[dpatch<0] = 0.
            dim = diff_ax.imshow((dpatch - patch), extent=(xlim[0], xlim[1], ylim[0], ylim[1]))
            plot_util.add_colorbar_to_axis(diff_ax, dim)
            msqe  = np.mean((dpatch - patch)**2)
            smsqe = np.mean((dpatch - patch)**2 / patch)
            diff_ax.set_title("diff, mse = %2.3f"%msqe)
Example #2
    def loss_func(self, w):
        loss = 0.5 * np.mean((self.train_y_ - np.dot(self.train_x_, w))**2)
        if self.penalty_ == "l1":  # Lasso
            loss += self.alpha_ * np.sum(np.abs(w[:-1]))
        elif self.penalty_ == "l2":  # Ridge
            loss += 0.5 * self.alpha_ * np.mean(w[:-1]**2)
        return loss
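The method above assumes the surrounding class stores train_x_, train_y_, alpha_ and penalty_. As a quick standalone sketch (hypothetical names and toy data, not part of the original class), the same L2-penalized objective can be minimized with plain gradient descent via autograd:

import autograd.numpy as np
from autograd import grad

def ridge_loss(w, X, y, alpha=0.1):
    # 0.5 * MSE data term plus an L2 penalty that leaves the bias w[-1] unpenalized,
    # mirroring the "l2" branch above
    return 0.5 * np.mean((y - np.dot(X, w))**2) + 0.5 * alpha * np.mean(w[:-1]**2)

X = np.hstack([np.random.randn(100, 3), np.ones((100, 1))])  # last column acts as the bias
y = np.dot(X, np.array([1., -2., 0.5, 0.3]))
w = np.zeros(4)
ridge_grad = grad(ridge_loss)
for _ in range(500):
    w = w - 0.1 * ridge_grad(w, X, y)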
Example #3
def BBVI(params,num_samples,num_particles,K,convergence):
    m = np.array([0.,0.])
    v = np.array([0.,0.])
    iterating = 1
    lower_bounds = []
    scaled_lower_bounds = []
    i=0
    while iterating==1:
        params,m,v,LB = iterate(params,num_samples,num_particles,i,m,v)
        #the scaling performs very poorly
        #LB/=M
        i+=1
        lower_bounds.append(-LB)
        if params[1]<=0:
            params = np.random.uniform(0,1,2)
            m = np.array([0.,0.])
            v = np.array([0.,0.])
        if i%10==0:
            print params, LB
        if len(lower_bounds)>K+1:
            lb2= np.mean(np.array(lower_bounds[-K:]))
            lb1 = np.mean(np.array(lower_bounds[-K-1:-1]))
            if abs(lb2-lb1)<convergence:
                iterating = 0
    return params,lower_bounds,i
Example #4
def location_mixture_logpdf(samps, locations, location_weights, distr_at_origin, contr_var = False, variant = 1):
#    lpdfs = zeroprop.logpdf()
    diff = samps - locations[:, np.newaxis, :]
    lpdfs = distr_at_origin.logpdf(diff.reshape([np.prod(diff.shape[:2]), diff.shape[-1]])).reshape(diff.shape[:2])
    logprop_weights = log(location_weights/location_weights.sum())[:, np.newaxis]
    if not contr_var: 
        return logsumexp(lpdfs + logprop_weights, 0)
    #time_m1 = np.hstack([time0[:,:-1],time0[:,-1:]])
    else:
        time0 = lpdfs + logprop_weights + log(len(location_weights))
        
        if variant == 1:
            time1 = np.hstack([time0[:,1:],time0[:,:1]])
            cov = np.mean(time0**2-time0*time1)
            var = np.mean((time0-time1)**2)
            lpdfs = lpdfs  -    cov/var * (time0-time1)        
            return logsumexp(lpdfs - log(len(location_weights)), 0)
        elif variant == 2:
            cvar = (time0[:,:,np.newaxis] - 
                    np.dstack([np.hstack([time0[:, 1:], time0[:, :1]]),
                               np.hstack([time0[:,-1:], time0[:,:-1]])]))

            
            ## self-covariance matrix of control variates
            K_cvar = np.diag(np.mean(cvar**2, (0, 1)))
            #add off diagonal
            K_cvar = K_cvar + (1.-np.eye(2)) * np.mean(cvar[:,:,0]*cvar[:,:,1])
            
            ## covariance of control variates with random variable
            cov = np.mean(time0[:,:,np.newaxis] * cvar, 0).mean(0)
            
            optimal_comb = np.linalg.inv(K_cvar) @ cov
            lpdfs = lpdfs  -  cvar @ optimal_comb
            return logsumexp(lpdfs - log(len(location_weights)), 0)
Example #5
    def plot_deep_gp_2d(ax,params,plot_xs):
        ax.cla()
        rs = npr.RandomState(0)

        sampled_means_and_covs = [sample_mean_cov_from_deep_gp(params, plot_xs) for i in xrange(n_samples)]
        sampled_means, sampled_covs = zip(*sampled_means_and_covs)
        avg_pred_mean = np.mean(sampled_means, axis = 0)
        avg_pred_cov = np.mean(sampled_covs, axis = 0)
        #print("X*",avg_pred_mean)
        #rint("X*",plot_xs[0:4])

        #sampled_means_and_covs_orig = [sample_mean_cov_from_deep_gp(params, X) for i in xrange(n_samples)]
        #sampled_means_orig, sampled_covs_orig = zip(*sampled_means_and_covs_orig)
        #avg_pred_mean_orig = np.mean(sampled_means_orig, axis = 0)
        #print("Orignal Xs",avg_pred_mean_orig)

        X0 = params[5:5+num_pseudo_params*2].reshape(num_pseudo_params,2)
        y0 = params[5+num_pseudo_params*2:5+num_pseudo_params*3]
        #ax.scatter(X0[:,0],X0[:,1],c = y0)

        avg_pred_mean = avg_pred_mean.reshape(40,40)
        ax.contourf(np.linspace(-1,1,40),np.linspace(-1,1,40), avg_pred_mean)
        ax.scatter(X[:,0],X[:,1],c=y)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title("Full Deep GP")
Example #6
def GetVideoTimeSeries():
    TS = []
    VIS_RY = []
    VIS_RX = []
    t0 = None
    lasthsum = None
    lastvsum = None
    hprior, vprior = 0, 0
    for msg in parse.ParseLog(open('../rustlerlog-BMPauR')):
        if msg[0] == 'img':
            _, ts, im = msg
            im = GammaCorrect(im)
            hsum = np.sum(im, axis=0)
            hsum -= np.mean(hsum)
            vsum = np.sum(im, axis=1)
            vsum -= np.mean(vsum)
            if t0 is None:
                t0 = ts
                lasthsum = hsum
                lastvsum = vsum
            hoffset = np.argmax(
                -2*np.arange(-80 - hprior, 81 - hprior)**2 +
                np.correlate(lasthsum, hsum[80:-80], mode='valid')) - 80
            voffset = np.argmax(
                -2*np.arange(-60 - vprior, 61 - vprior)**2 +
                np.correlate(lastvsum, vsum[60:-60], mode='valid')) - 60
            TS.append(ts - t0)
            VIS_RY.append(hoffset)
            VIS_RX.append(voffset)
            hprior, vprior = hoffset, voffset
            lasthsum = hsum
            lastvsum = vsum
    return TS, VIS_RY, VIS_RX
Example #7
def BBVI(params,num_samples,num_particles,K,convergence):
    m = np.array([0.,0.])
    v = np.array([0.,0.])
    lower_bounds = []
    iterating = 1
    i=0
    i_true = 0
    while iterating==1:
        params,m,v,LB = iterate(params,num_samples,num_particles,i,m,v)
        i_true += 1
        i+=1
        lower_bounds.append(-LB)
        if params[0]<=0 or params[1]<=0:
            i=0
            params = np.random.uniform(10,100,2)
            m = np.array([0.,0.])
            v = np.array([0.,0.])
        if i%100==0:
            print params
        if len(lower_bounds)>K+1:
            lb2= np.mean(np.array(lower_bounds[-K:])/n)
            lb1 = np.mean(np.array(lower_bounds[-K-1:-1])/n)
            if abs(lb2-lb1)<convergence:
                iterating = 0
    return params, lower_bounds, i_true,all_gradients
Example #8
def avabc(params,num_samples,num_particles,K,convergence):
    lower_bounds = []
    scaled_lower_bounds = []
    iterating = 1
    i=0
    m = np.array([0.,0.])
    v = np.array([0.,0.])
    while iterating==1:
        params,m,v,LB = iterate(params,i,m,v,num_samples,num_particles)
        #LB/=M
        if params[1]<=0 or np.isnan(params).any():
            params = np.random.uniform(0,1,2)
            m = np.array([0.,0.])
            v = np.array([0.,0.])
        i+=1
        lower_bounds.append(-LB)
        if len(lower_bounds)>K+1:
            lb2 = np.mean(np.array(lower_bounds[-K:]))
            lb1 = np.mean(np.array(lower_bounds[-K-1:-1]))
            scaled_lower_bounds.append(-LB)
            if abs(lb2-lb1)<convergence:
                iterating = 0
            if i%10==0:
                print abs(lb2-lb1)
            if np.isnan(abs(lb2-lb1)):
                lower_bounds=[]
        if i%10==0:
            print params, LB
    return params, scaled_lower_bounds,i
Example #9
def get_error_and_ll(w, v_prior, X, y, K, location, scale):
    v_noise = np.exp(parser.get(w, 'log_v_noise')[ 0, 0 ]) * scale**2
    q = get_parameters_q(w, v_prior)
    samples_q = draw_samples(q, K)
    outputs = predict(samples_q, X) * scale + location
    log_factor = -0.5 * np.log(2 * math.pi * v_noise) - 0.5 * (np.tile(y, (1, K)) - np.array(outputs))**2 / v_noise
    ll = np.mean(logsumexp(log_factor - np.log(K), 1))
    error = np.sqrt(np.mean((y - np.mean(outputs, 1, keepdims = True))**2))
    return error, ll
Example #10
def print_perf(gen_params, dsc_params, iter, gen_gradient, dsc_gradient):
    if iter % 10 == 0:
        ability = np.mean(objective(gen_params, dsc_params, iter))
        fake_data = generate_from_noise(gen_params, 20, noise_dim, seed)
        real_data = train_images[batch_indices(iter)]
        probs_fake = np.mean(sigmoid(neural_net_predict(dsc_params, fake_data)))
        probs_real = np.mean(sigmoid(neural_net_predict(dsc_params, real_data)))
        print("{:15}|{:20}|{:20}|{:20}".format(iter//num_batches, ability, probs_fake, probs_real))
        save_images(fake_data, 'gan_samples.png', vmin=0, vmax=1)
Example #11
    def train_epoch(self, network):
        self._setup(network)
        losses = []

        X_batch = batch_iterator(network.X, network.batch_size)
        y_batch = batch_iterator(network.y, network.batch_size)
        for X, y in tqdm(zip(X_batch, y_batch), "Epoch progress"):
            loss = np.mean(network.update(X, y))
            self.update(network)
            losses.append(loss)
        epoch_loss = np.mean(losses)
        return epoch_loss
Example #12
def callback(weights, iter):
    if iter % 10 == 0:
        print "max of weights", np.max(np.abs(weights))
        train_preds = undo_norm(pred_fun(weights, train_smiles))
        cur_loss = loss_fun(weights, train_smiles, train_targets)
        training_curve.append(cur_loss)
        print "Iteration", iter, "loss", cur_loss, "train RMSE", \
            np.sqrt(np.mean((train_preds - train_raw_targets)**2)),
        if validation_smiles is not None:
            validation_preds = undo_norm(pred_fun(weights, validation_smiles))
            print "Validation RMSE", iter, ":", \
                np.sqrt(np.mean((validation_preds - validation_raw_targets) ** 2)),
Example #13
def train_loss(wb_vect, unflattener, cv=False, batch=True, batch_size=10,
               debug=False):
    """
    Training loss is MSE.

    We pass in a flattened parameter vector and its unflattener.
    """
    wb_struct = unflattener(wb_vect)

    if batch:
        batch_size = batch_size
    else:
        batch_size = len(graphs)

    if cv:
        samp_graphs, samp_inputs = batch_sample(test_graphs,
                                                input_shape,
                                                batch_size=batch_size)
    else:
        samp_graphs, samp_inputs = batch_sample(graphs,
                                                input_shape,
                                                batch_size=batch_size)

    print('batch size: {0}'.format(len(samp_graphs)))
    preds = predict(wb_struct, samp_inputs, samp_graphs)
    graph_ids = [g.graph['seqid'] for g in samp_graphs]
    graph_scores = drug_data.set_index('seqid').ix[graph_ids]['FPV'].values.\
        reshape(preds.shape)

    assert preds.shape == graph_scores.shape

    mse = np.mean(np.power(preds - graph_scores, 2))

    if debug:
        print(graph_ids)
        print('Predictions:')
        print(preds)
        print('Mean: {0}'.format(np.mean(preds)))
        print('')
        print('Actual')
        print(graph_scores)
        print('Mean: {0}'.format(np.mean(graph_scores)))
        print('')
        print('Difference')
        print(preds - graph_scores)
        print('Mean Squared Error: {0}'.format(mse))
        print('')

    return mse
Example #14
def magcal_residual(X, a, mb):
    """ residual from all observations given magnetometer eccentricity, bias,
    gyro bias, and gyro scale"""

    # (x-c)T A^T A (x-c) = 1
    # x^T Ax - 2x^T Ac + c^T Ac = 1

    # a b c | x' = ax + by + cz
    # 0 d e | y' = dy + ez
    # 0 0 f | z' = fz
    # z = 1/f z'
    # y = 1/d (y' - e/f z')
    # x = 1/a (x' - b/d(y' - e/f z') - c/f z')
    #   = 1/a (x' - b/d y' - (be/df - c/f) z')
    # (x-c) A^T A (x-c)
    # [(A x) - (A c)]^2 - 1 = 0

    # y = A(x-c)
    # y /= ||y||
    # q(x; A, c) = (A^-1 (y+c) - x)^2

    Y = np.dot(X - mb, Ainv(a)).T
    Y /= np.linalg.norm(Y, axis=0)
    # Y /= np.sqrt(np.sum(np.square(Y), axis=0))
    Y = np.dot(Y.T, Amatrix(a)) + mb
    return np.mean(np.sum(np.square(X - Y), axis=1))
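magcal_residual relies on Amatrix and Ainv, which are not shown in this excerpt. Judging from the comment block above they build the upper-triangular calibration matrix and its inverse; the sketch below is only a guess at those helpers, with an assumed packing of the parameter vector a:

import numpy as np

def Amatrix(a):
    # assumed packing a = (a, b, c, d, e, f) for the upper-triangular matrix
    # [[a, b, c], [0, d, e], [0, 0, f]] described in the comments above
    return np.array([[a[0], a[1], a[2]],
                     [0.0,  a[3], a[4]],
                     [0.0,  0.0,  a[5]]])

def Ainv(a):
    # inverse of the same (small) triangular matrix
    return np.linalg.inv(Amatrix(a))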
Example #15
    def elbo(params, t):
        '''
        samples: [n_samples, D]
        u: [D,1]
        w: [D,1]
        b: [1]
        '''

        mean = params[0]
        log_std = params[1]
        u = params[2]
        w = params[3]
        b = params[4]

        samples = sample_diag_gaussian(mean, log_std, num_samples, rs)
        z_k = normalizing_flows(samples, u, w, b)

        logp_zk = logprob(z_k)
        logp_zk = np.reshape(logp_zk, [num_samples, 1])

        logq_zk = variational_log_density(params, samples)
        logq_zk = np.reshape(logq_zk, [num_samples, 1])

        elbo = logp_zk - logq_zk
  
        return np.mean(elbo) #over samples
Example #16
def variational_objective(params, t):
    """Provides a stochastic estimate of the variational lower bound."""
    mean, log_std,inputs, len_sc, variance = unpack_params(params)
    samples = rs.randn(num_samples, D) * np.exp(log_std) + mean
    print(log_std)
    lower_bound = gaussian_entropy(log_std) + np.mean(logprob(samples,inputs,len_sc,variance, t))
    return -lower_bound
Example #17
def normalize_array(A):
    mean, std = np.mean(A), np.std(A)
    A_normed = (A - mean) / std
    def restore_function(X):
        return X * std + mean

    return A_normed, restore_function
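A quick usage sketch for normalize_array: normalize targets before fitting, then map predictions back to the original scale with the returned closure (the "predictions" here are just a stand-in):

import numpy as np

y = np.array([10., 12., 9., 15.])
y_normed, restore = normalize_array(y)
preds_normed = y_normed           # stand-in for model predictions on the normalized scale
preds = restore(preds_normed)     # back on the original scale; equals y here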
Example #18
    def elbo(params, t):
        '''
        samples: [n_samples, D]
        u: [D,1]
        w: [D,1]
        b: [1]
        '''

        beta = t/100 + .001

        if beta > .99:
            beta = 1.

        beta = 1

        mean = params[0]
        log_std = params[1]
        norm_flow_params = params[2]

        samples = sample_diag_gaussian(mean, log_std, n_samples, rs)
        z_k, all_zs = normalizing_flows(samples, norm_flow_params)

        logp_zk = logprob(z_k)
        logp_zk = np.reshape(logp_zk, [n_samples, 1])

        logq_zk = variational_log_density(params, samples)
        logq_zk = np.reshape(logq_zk, [n_samples, 1])

        elbo = (beta*logp_zk) - logq_zk 
  
        return np.mean(elbo) #over samples
Example #19
def H_i(samples,params,n,k,i,num_particles):
    H_i = 0
    S = len(samples)
    c = c_i(params,n,k,i,S,num_particles)
    inner = (h_s(samples,n,k,num_particles)+c)*gradient_log_recognition(params,samples,i)
    H_i = np.mean(inner)
    return H_i
Example #20
def KL_via_sampling(params,a2,b2,U):
    a1 = params[0]
    b1 = params[1]
    theta = generate_kumaraswamy(params,U)
    E = np.log(kumaraswamy_pdf(theta,params)/kumaraswamy_pdf(theta,np.array([a2,b2])))
    E = np.mean(E)
    return E
Example #21
def meddistance(X, subsample=None, mean_on_fail=True):
    """
    Compute the median of pairwise distances (not distance squared) of points
    in the matrix.  Useful as a heuristic for setting Gaussian kernel's width.

    Parameters
    ----------
    X : n x d numpy array
    mean_on_fail: True/False. If True, use the mean when the median distance is 0.
        This can happen especially when the data are discrete, e.g., 0/1, and
        there are slightly more 0s than 1s. In this case, the median can be 0 and
        the mean is used instead.

    Return
    ------
    median distance
    """
    if subsample is None:
        D = dist_matrix(X, X)
        Itri = np.tril_indices(D.shape[0], -1)
        Tri = D[Itri]
        med = np.median(Tri)
        if med <= 0:
            # use the mean
            return np.mean(Tri)
        return med

    else:
        assert subsample > 0
        rand_state = np.random.get_state()
        np.random.seed(9827)
        n = X.shape[0]
        ind = np.random.choice(n, min(subsample, n), replace=False)
        np.random.set_state(rand_state)
        # recursion just one
        return meddistance(X[ind, :], None, mean_on_fail)
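meddistance depends on a dist_matrix helper that is not shown above; the stand-in below computes plain pairwise Euclidean distances. A short usage sketch of the median heuristic for choosing an RBF kernel bandwidth:

import numpy as np

def dist_matrix(X, Y):
    # pairwise Euclidean distances between rows of X and rows of Y (stand-in)
    sq = np.sum(X**2, 1)[:, None] + np.sum(Y**2, 1)[None, :] - 2 * np.dot(X, Y.T)
    return np.sqrt(np.maximum(sq, 0))

X = np.random.randn(200, 5)
sigma = meddistance(X, subsample=100)   # bandwidth from the median heuristic
gamma = 1.0 / (2 * sigma**2)            # e.g. for k(x, y) = exp(-gamma * ||x - y||^2)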
Example #22
File: ml.py Project: gablg1/ml-util
def normalizeFeatures(X_train, X_test):
    mean_X_train = np.mean(X_train, 0)
    std_X_train = np.std(X_train, 0)
    std_X_train[ std_X_train == 0 ] = 1
    X_train_normalized = (X_train - mean_X_train) / std_X_train
    X_test_normalized = (X_test - mean_X_train) / std_X_train
    return X_train_normalized, X_test_normalized
Example #23
def KL_via_sampling(params,mu2,sigma2,U):
    mu = params[0]
    sigma = params[1]
    theta = generate_lognormal(params,U)
    E = np.log(lognormal_pdf(theta,params)/lognormal_pdf(theta,np.array([mu2,sigma2])))
    E = np.mean(E)
    return E
Example #24
def fit_mog(data, max_comps = 20, mog_class = MixtureOfGaussians):
    from sklearn import mixture
    N            = data.shape[0]
    if len(data.shape) == 1:
        train = data[:int(.75 * N)]
        test  = data[int(.75 * N):]
    else:
        train = data[:int(.75*N), :]
        test  = data[int(.75*N):, :]

    # do train/val GMM fit
    num_comps = np.arange(1, max_comps+1)
    scores    = np.zeros(len(num_comps))
    for i, num_comp in enumerate(num_comps):
        g = mixture.GMM(n_components=num_comp, covariance_type='full')
        g.fit(train)
        logprobs, res = g.score_samples(test)
        scores[i] = np.mean(logprobs)
        print "num_comp = %d (of %d) score = %2.4f"%(num_comp, max_comps, scores[i])
    print "best validation, num_comps = %d"%num_comps[scores.argmax()]

    # fit final model to all data
    g = mixture.GMM(n_components = num_comps[scores.argmax()], covariance_type='full')
    g.fit(data)

    # create my own GMM object - it's better!
    return mog_class(g.means_, g.covars_, g.weights_)
Example #25
    def print_perf(combined_params, iter, grad):
        if iter % 10 == 0:
            gen_params, rec_params = combined_params
            bound = np.mean(objective(combined_params, iter))
            print("{:15}|{:20}".format(iter//num_batches, bound))

            fake_data = generate_from_noise(gen_params, 20, latent_dim, seed)
            save_images(fake_data, 'vae_samples.png', vmin=0, vmax=1)
Example #26
def loss_fun(weights, smiles, targets):
    fingerprint_weights, net_weights = unpack_weights(weights)
    fingerprints = fingerprint_func(fingerprint_weights, smiles)
    net_loss = net_loss_fun(net_weights, fingerprints, targets)
    if len(fingerprint_weights) > 0 and fp_l2_penalty > 0:
        return net_loss + fp_l2_penalty * np.mean(fingerprint_weights**2)
    else:
        return net_loss
Example #27
File: glm.py Project: onenoc/lfvbae
def KL_via_sampling(params,eps):
    theta = params[0]+np.exp(params[1])*eps
    muPrior = 0
    sigmaPrior = 1
    paramsPrior = np.array([muPrior,sigmaPrior])
    E = np.log(normal_pdf(theta,params)/normal_pdf(theta,paramsPrior))
    E = np.mean(E)
    return E
Example #28
def grad_expectation(params,n,k,U1,U2,v):
    E=0
    theta = generate_kumaraswamy(params,U2)
    grad_kuma_pdf = grad(kumaraswamy_pdf)
    f=likelihood(n,k,theta,i)
    g = grad_kuma_pdf(theta,params)
    E = f*g
    return np.mean(E)
Example #29
    def plot_deep_gp(ax, params, plot_xs):
        ax.cla()
        rs = npr.RandomState(0)
        
        sampled_means_and_covs = [sample_mean_cov_from_deep_gp(params, plot_xs, rs = rs, with_noise = False, FITC = False) for i in xrange(50)]
        sampled_means, sampled_covs = zip(*sampled_means_and_covs)
        avg_pred_mean = np.mean(sampled_means, axis = 0)
        avg_pred_cov = np.mean(sampled_covs, axis = 0)

        sampled_means_and_covs_2 = [sample_mean_cov_from_deep_gp(params, plot_xs, rs = rs, with_noise = False, FITC = False) for i in xrange(n_samples_to_plot)]
        sampled_funcs = np.array([rs.multivariate_normal(mean, cov*(random)) for mean,cov in sampled_means_and_covs_2])
        ax.plot(plot_xs,sampled_funcs.T)
        ax.plot(X, y, 'kx')
        ax.plot(plot_xs,avg_pred_mean,'r--')
        #ax.set_ylim([-1.5,1.5])
        ax.set_xticks([])
        ax.set_yticks([])
Example #30
def fit_maxlike(data, r_guess):
    # follows Wikipedia's section on negative binomial max likelihood
    assert np.var(data) > np.mean(data), "Likelihood-maximizing parameters don't exist!"
    loglike = lambda r, p: np.sum(negbin_loglike(r, p, data))
    p = lambda r: np.sum(data) / np.sum(r+data)
    rprime = lambda r: grad(loglike)(r, p(r))
    r = newton(rprime, r_guess)
    return r, p(r)
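fit_maxlike assumes a negbin_loglike helper plus grad (from autograd) and newton (from scipy.optimize), none of which are shown here. One plausible definition of the log-likelihood under the standard (r, p) parametrization, written with autograd so that grad(loglike) works, is sketched below; it is a guess, not necessarily the original project's definition:

import autograd.numpy as np
from autograd.scipy.special import gammaln
from autograd import grad            # also used by fit_maxlike above
from scipy.optimize import newton    # also used by fit_maxlike above

def negbin_loglike(r, p, x):
    # log pmf of the negative binomial: Gamma(x + r) / (Gamma(r) x!) * p^x * (1 - p)^r
    return gammaln(r + x) - gammaln(r) - gammaln(x + 1) \
        + x * np.log(p) + r * np.log(1 - p)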
Example #31
                                inputs=data['inputs'],
                                targets=data['one_hot_targets'],
                                hps=hps)
        params = update_params(params, gradients_0, hps['learning_rates'])

    for epoch in range(num_epochs - 1):
        # for epoch in range(100):
        gradients_1 = loss_grad(params,
                                inputs=data['inputs'],
                                targets=data['one_hot_targets'],
                                hps=hps)
        hps['learning_rates'] = update_lr(gradients_0, gradients_1,
                                          hps['learning_rates'],
                                          hps['hyper_learning_rate'])
        params = update_params(params, gradients_1, hps['learning_rates'])

        gradients_0 = copy.deepcopy(gradients_1)

    print(
        'loss after training: ',
        loss(params,
             inputs=data['inputs'],
             targets=data['one_hot_targets'],
             hps=hps))

    print(
        np.mean(
            np.equal(
                np.argmax(forward(params, inputs=data['inputs'], hps=hps)[-1],
                          axis=1), data['labels_indexed'])))
Example #32
def mse(y_pred, y_true):
    return np.mean(np.power(y_true-y_pred, 2))
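A tiny sanity check of the helper above:

import numpy as np

y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.1, 1.9, 3.2])
print(mse(y_pred, y_true))   # approximately 0.02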
Example #33
    def initialize(self,
                   base_model,
                   datas,
                   inputs=None,
                   masks=None,
                   tags=None,
                   num_em_iters=50,
                   num_tr_iters=50):

        print("Initializing...")
        print("First with FA using {} steps of EM.".format(num_em_iters))
        fa, xhats, Cov_xhats, lls = factor_analysis_with_imputation(
            self.D, datas, masks=masks, num_iters=num_em_iters)

        if self.D == 1 and base_model.transitions.__class__.__name__ == "DDMTransitions":

            d_init = np.mean([y[0:3] for y in datas], axis=(0, 1))
            u_sum = np.array([np.sum(u) for u in inputs])
            y_end = np.array([y[-3:] for y in datas])
            u_l, u_u = np.percentile(
                u_sum, [20, 80])  # use 20th and 80th percentile input
            y_U = y_end[np.where(u_sum >= u_u)]
            y_L = y_end[np.where(u_sum <= u_l)]
            C_init = (1.0 / 2.0) * np.mean(
                (np.mean(y_U, axis=0) - np.mean(y_L, axis=0)), axis=0)

            self.Cs = C_init.reshape([1, self.N, self.D])
            self.ds = d_init.reshape([1, self.N])
            self.inv_etas = np.log(fa.sigmasq).reshape([1, self.N])

        else:

            # define objective
            Td = sum([x.shape[0] for x in xhats])

            def _objective(params, itr):
                new_datas = [np.dot(x, params[0].T) + params[1] for x in xhats]
                obj = base_model.log_likelihood(new_datas, inputs=inputs)
                return -obj / Td

            # initialize R and r
            R = 0.1 * np.random.randn(self.D, self.D)
            r = 0.01 * np.random.randn(self.D)
            params = [R, r]

            print(
                "Next by transforming latents to match AR-HMM prior using {} steps of max log likelihood."
                .format(num_tr_iters))
            state = None
            lls = [-_objective(params, 0) * Td]
            pbar = trange(num_tr_iters)
            pbar.set_description("Epoch {} Itr {} LP: {:.1f}".format(
                0, 0, lls[-1]))

            for itr in pbar:
                params, val, g, state = sgd_step(value_and_grad(_objective),
                                                 params, itr, state)
                lls.append(-val * Td)
                pbar.set_description("LP: {:.1f}".format(lls[-1]))
                pbar.update(1)

            R = params[0]
            r = params[1]

            # scale x's to be max at 1.1
            for d in range(self.D):
                x_transformed = [(np.dot(x, R.T) + r)[:, d] for x in xhats]
                max_x = np.max(x_transformed)
                R[d, :] *= 1.1 / max_x
                r[d] *= 1.1 / max_x

            self.Cs = (fa.W @ np.linalg.inv(R)).reshape([1, self.N, self.D])
            self.ds = fa.mean - fa.W @ np.linalg.inv(R) @ r
            self.inv_etas = np.log(fa.sigmasq).reshape([1, self.N])