def ksne_obj(params, X, K, P, pijs, count_pot_dict, T, fast=True):
    """Evaluate the k-SNE objective for a flattened embedding.

    The returned value is

        sum_ij pijs[i, j] * s_ij  -  sum_i log Z_i

    where ``s_ij = -log(1 + ||z_i - z_j||^2)`` is a heavy-tailed similarity
    between embedded points and ``log Z_i`` is the log-partition value that
    the inference routine reports for row ``i`` with the self-pair (i, i)
    excluded.

    Parameters
    ----------
    params : array with N*P entries; reshaped to the (N, P) embedding Z.
    X : input data; only ``X.shape[0]`` (the number of points N) is used.
    K : neighbour count, passed to ``ChainAlg`` as ``Kmin=Kmax=K``
        (fast path only).
    P : dimensionality of the embedding.
    pijs : (N, N) array of target values weighting the similarities.
    count_pot_dict, T : count potentials and tree handed to
        ``sc.conv_tree`` (slow path only).
    fast : if True, compute all log-partition values with one
        ``ChainAlg.infer`` call; otherwise call ``sc.conv_tree`` per row.

    Returns
    -------
    Scalar objective value.
    """
    N = X.shape[0]
    Z = params.reshape((N, P))
    # Heavy-tailed log-similarities -log(1 + d_ij). log1p is more accurate than
    # log(1 + d) for small d. The diagonal is -log1p(0) = 0.
    tdijs = -np.log1p(squareform(pdist(Z, 'sqeuclidean')))

    term1 = np.sum(tdijs * pijs)  # diagonal of tdijs is 0, so self-pairs add nothing

    if fast:
        from chain.chain_alg_wrapper import ChainAlg
        chain = ChainAlg(N, Kmin=K, Kmax=K, minibatch_size=N)
        # A huge negative potential on the diagonal effectively forbids a
        # point from choosing itself.
        pot = tdijs - np.eye(N) * 1e100
        _, _, logZs = chain.infer(pot)
        sum_log_Z = logZs.sum()
    else:
        sum_log_Z = 0
        exp_tdijs = np.exp(tdijs)  # loop-invariant; computed once
        for i in range(N):
            # Copy: the row is a view and we zero out the self-pair below.
            weights = exp_tdijs[i, :].copy()
            weights[i] = 0
            _, _, log_Z = sc.conv_tree(weights, count_pot_dict, T, use_fft=False)
            sum_log_Z += log_Z

    return term1 - sum_log_Z
def ksne_grad_correct(params, X, K, P, pijs, count_pot_dict, T):
    """Compute a reference gradient of the embedding objective via sc.conv_tree.

    For each row i, node marginals are obtained from ``sc.conv_tree`` on
    ``exp(-||z_i - z_j||^2)`` with the self-pair (i, i) zeroed out, and

        G[i, :] = pijs[i, :] - node_margs_i
        g2      = 2 * (G + G.T)
        dZ[k]   = sum_j g2[k, j] * (Z[j] - Z[k])

    If ``node_margs`` are the derivatives of ``log Z_i`` w.r.t. the
    similarities (assumed, not visible here), dZ is the gradient of
    ``sum_ij pijs_ij * s_ij - sum_i log Z_i`` with ``s_ij = -||z_i - z_j||^2``.

    Parameters
    ----------
    params : array with N*P entries; reshaped to the (N, P) embedding Z.
    X : input data; only ``X.shape[0]`` (N) is used.
    K : unused; kept for signature compatibility with the other ksne_* functions.
    P : dimensionality of the embedding.
    pijs : (N, N) array of target values.
    count_pot_dict, T : count potentials and tree handed to ``sc.conv_tree``.

    Returns
    -------
    (N, P) array dZ.
    """
    # N used to be an undefined name here (NameError unless a global N existed).
    N = X.shape[0]
    Z = params.reshape((N, P))
    # NOTE(review): this uses s_ij = -||z_i - z_j||^2, whereas ksne_obj uses
    # s_ij = -log(1 + ||z_i - z_j||^2); confirm which objective this gradient
    # is meant to match.
    tdijs = -squareform(pdist(Z, 'sqeuclidean'))
    exp_tdijs = np.exp(tdijs)  # loop-invariant; computed once

    grad = pijs.copy()
    for i in range(N):
        # Copy: the row is a view and we zero out the self-pair below.
        weights = exp_tdijs[i, :].copy()
        weights[i] = 0
        node_margs, _, _ = sc.conv_tree(weights, count_pot_dict, T, use_fft=False)
        grad[i, :] -= node_margs

    # Chain rule through d_ij = ||z_i - z_j||^2 for both (i, j) and (j, i).
    g2 = 2 * (grad + grad.T)
    dZ = g2.dot(Z) - g2.sum(1)[:, np.newaxis] * Z

    return dZ
def train_sgd(X, Y, dataset_name, P, K, Z=None, num_iters=1, perplexity=5, batch_size=10,
              eta=0.1, mo=0, L2=0, class_to_keep=None, seed=1, distance='sqeuclidean',
              text_labels=None):
    """Fit a P-dimensional embedding of X with momentum gradient steps.

    X, Y: points and their respective labels
    dataset_name: the name given to this dataset
    P: the dimensionality of the embedding space
    K: the number of neighbors
    Z: initial coordinates of the points in embedded space. Defaults to
       small random Gaussian values. A float copy is taken, so the caller's
       array is not modified.
    num_iters: number of passes over the set of points
    perplexity: perplexity of the distribution over datapoints. Check:
                'Visualizing Data using t-SNE' by van der Maaten and Hinton,
                JMLR 2008, for details.
    batch_size: size of mini batches (see NOTE in the training loop)
    eta: learning rate
    mo: momentum
    L2: size of L2 regularizer
    class_to_keep: subset of labels that have been kept (for naming results
                   and figure files); defaults to an empty list
    seed: the value of the seed (for naming results and figure files)
    distance: distance metric {sqeuclidean, emd}
    text_labels: alternative labels used in plots

    Returns the final embedding Z as an (N, P) array.

    Raises ValueError if `distance` is not 'sqeuclidean' or 'emd'.
    """
    # Validate before any plotting or expensive work (previously an unknown
    # metric surfaced later as an UnboundLocalError on `dijs`).
    if distance not in ('sqeuclidean', 'emd'):
        raise ValueError("unknown distance %r; expected 'sqeuclidean' or 'emd'" % (distance,))
    if class_to_keep is None:
        class_to_keep = []

    N = X.shape[0]   # num points

    ##################################################################################
    if Z is None:
        Z = np.random.randn(N, P) * 0.1  # points in embedded space
    else:
        # Z is updated in place below; work on a copy so the caller's array is untouched.
        Z = np.array(Z, dtype=float)

    if P == 2:  # plot if embedding is in 2D
        plot_embedding(Z, Y, K, eta, mo, 0, target_entropy=perplexity, iters=0,
                       tot_iters=num_iters, class_to_keep=class_to_keep, seed=seed,
                       dataset_name=dataset_name, text_labels=text_labels)

    # Negated pairwise input distances, normalized by dist_normalize.
    if distance == 'sqeuclidean':
        dijs = -squareform(pdist(X, 'sqeuclidean'))
    else:  # 'emd': earth mover's distance
        from emd.emddist import emddist
        dijs = -emddist(X)
    dijs = dist_normalize(dijs, Y, perplexity=perplexity, dataset_name=dataset_name)
    # Subtract HUGE_VAL on the diagonal (presumably to rule out self-pairs).
    dijs -= HUGE_VAL * np.eye(N)

    T = sc.make_balanced_binary_tree(N)
    root_idx = N + T.shape[0] - 1
    print("Root idx %s" % root_idx)

    # Count potential at the tree root: all mass on exactly K selected points.
    global_count_potential = np.zeros(N + 1)
    global_count_potential[K] = 1
    only_global_constraint_dict = {root_idx: global_count_potential}

    # Precompute E_i[y_j] (or pijs)
    print('precomute E_i[y_j]')

    from chain.chain_alg_wrapper import ChainAlg
    chain = ChainAlg(N, Kmin=K, Kmax=K, minibatch_size=N)
    pijs, _, _ = chain.infer(dijs)

    debug = False
    if debug:
        # Cross-check the ChainAlg marginals against per-row sc.conv_tree.
        pijs_debug = pijs.copy()
        for i in range(N):
            node_margs, _, _ = sc.conv_tree(np.exp(dijs[i, :]), only_global_constraint_dict,
                                            T, use_fft=False)
            pijs_debug[i, :] = node_margs

        diff = abs(pijs - pijs_debug).max()
        assert diff < 1e-8
        print('precomputation of marginals is correct!')

    ##################################################################################

    num_batches = np.ceil(np.double(N) / batch_size)
    rand_indices = np.random.permutation(N)

    print('num_batches = %s' % num_batches)

    V = np.zeros(Z.shape)  # momentum (velocity) term
    for i in range(num_iters):
        f_tot = 0
        for batch in range(int(num_batches)):
            print("iteration %d batch %d of %d" % (i + 1, batch + 1, int(num_batches)))
            # NOTE(review): `ind` is computed but unused -- the objective and
            # gradient below are evaluated on all N points for every batch.
            ind = rand_indices[np.mod(range(batch * batch_size, (batch + 1) * batch_size), N)]
            f_tot += ksne_obj(Z.flatten(), X, K, P, pijs, only_global_constraint_dict, T)

            # Gradient at the look-ahead point Z + mo*V (Nesterov-style momentum).
            g = ksne_grad((Z + V * mo).flatten(), X, K, P, pijs,
                          only_global_constraint_dict, T).reshape(Z.shape)

            V = V * mo + eta * (g - L2 * (Z + V * mo))
            Z += V
        print('objective: %s, |g|=%s' % (str(f_tot / num_batches), abs(g).mean()))
        if P == 2:  # plot if embedding is in 2D
            plot_embedding(Z, Y, K, eta, mo, f_tot / num_batches, target_entropy=perplexity,
                           iters=i + 1, tot_iters=num_iters, class_to_keep=class_to_keep,
                           seed=seed, dataset_name=dataset_name, text_labels=text_labels)

    return Z