def main():
    """Print the 20 vocabulary entries closest to a query word.

    Command-line arguments used here:
      * ``sys.argv[4]`` -- path to a pickled sequence ``Q`` of word vectors,
        indexed by word id.
      * ``sys.argv[5]`` -- the query word.

    Every vector in ``Q`` is compared against the query word's vector with
    ``cosine`` and the 20 ``(word_id, word, value)`` triples with the
    smallest value are printed.

    NOTE(review): ``train_data`` is not defined in this function; it is
    assumed to exist at module level -- confirm.
    NOTE(review): ``cosine`` is presumably a cosine *distance* (smaller means
    more similar), which is what the ascending sort implies -- confirm.
    """
    logger.info("Creating vocabulary dictionary...")
    vocab = Dictionary.from_corpus(train_data, unk='<unk>')
    logger.info("Creating tag dictionary...")
    vocab_tags = Dictionary.from_corpus_tags(train_data, unk='<unk>')
    vocab.add_word('<s>')
    vocab.add_word('</s>')

    vocab_tags.add_word('<s>')
    vocab_tags.add_word('</s>')

    # One row per tag, one column per sub-tag.
    # NOTE(review): feature_matrix is never read in this function; it looks
    # like a leftover copied from train_lbl and can probably be dropped.
    feature_matrix = np.zeros((vocab_tags.size(), vocab_tags.num_sub_tags))
    feature_matrix[(0, 0)] = 1  # unk encoding

    for tag, tag_id in vocab_tags:
        if tag == "<s>":
            feature_matrix[(tag_id, 1)] = 1
        elif tag == "</s>":
            feature_matrix[(tag_id, 2)] = 1
        else:
            for sub_tag in vocab_tags.map_tag_to_sub_tags[tag]:
                val = vocab_tags.map_sub_to_ids[sub_tag]
                feature_matrix[(tag_id, val)] = 1

    # SECURITY: unpickling can execute arbitrary code. Only pass files that
    # were produced by a trusted run of this project.
    with open(sys.argv[4], 'rb') as embeddings_file:
        Q = cPickle.load(embeddings_file)

    # Single-argument parenthesised print behaves the same on Python 2 and 3.
    print("START COMPARING")

    word = sys.argv[5]
    word_id = vocab.lookup_id(word)
    query_vector = Q[word_id]

    words = [(j, vocab.lookup_word(j), cosine(query_vector, q))
             for j, q in enumerate(Q)]
    # Sort once, after all entries are collected. list.sort is stable, so
    # the result matches re-sorting after every append.
    words.sort(key=lambda x: x[2])
    print(words[:20])
def train_lbl(train_data, dev_data, test_data=None,
              K=20, context_sz=2, learning_rate=1.0,
              rate_update='simple', epochs=10,
              batch_size=1, rng=None, patience=None,
              patience_incr=2, improvement_thrs=0.995,
              validation_freq=1000):
    """Train a log-bilinear language model with minibatch SGD.

    Word and tag dictionaries are built from ``train_data`` (plus ``<s>`` and
    ``</s>``), instances are created with ``make_instances`` and a
    ``LogBilinearLanguageModel`` is trained on them. Dev perplexity is
    measured every ``validation_freq`` minibatches; test perplexity is
    measured whenever dev perplexity reaches a new best.

    Parameters:
        train_data, dev_data, test_data: corpora handed to
            ``Dictionary.from_corpus*`` / ``make_instances``. ``test_data``
            defaults to an empty corpus.
        K: passed to ``LogBilinearLanguageModel`` (presumably the embedding
            dimensionality -- confirm against the model class).
        context_sz: context size passed to ``make_instances`` and the model.
        learning_rate: initial SGD step size.
        rate_update: learning-rate schedule applied at the end of each epoch:
            ``'simple'`` sets it to ``1 / (epoch + 1)``, ``'adaptive'`` halves
            it when dev perplexity increased over the previous epoch,
            ``'constant'`` leaves it unchanged.
        epochs: maximum number of passes over the training set.
        batch_size: instances per minibatch; a trailing incomplete batch is
            dropped by the integer division below.
        rng: ``numpy.random.RandomState``; a fresh one is created if omitted.
        patience: iteration index after which training stops early; a falsy
            value disables early stopping.
        patience_incr: factor used to extend ``patience`` on a good
            improvement.
        improvement_thrs: relative dev-perplexity threshold that counts as a
            good improvement.
        validation_freq: number of minibatches between dev evaluations.

    Returns:
        The trained ``LogBilinearLanguageModel`` instance.

    Raises:
        ValueError: if ``rate_update`` is not one of the known strategies
            (checked before any work is done).
    """
    # Fail fast instead of after a full epoch of training.
    if rate_update not in ('simple', 'adaptive', 'constant'):
        raise ValueError("Unknown learning rate update strategy: %s" % rate_update)
    if test_data is None:
        test_data = []

    # create vocabulary from train data, plus <s>, </s>
    logger.info("Creating vocabulary dictionary...")
    vocab = Dictionary.from_corpus(train_data, unk='<unk>')
    logger.info("Creating tag dictionary...")
    vocab_tags = Dictionary.from_corpus_tags(train_data, unk='<unk>')
    vocab.add_word('<s>')
    vocab.add_word('</s>')
    V = vocab.size()

    vocab_tags.add_word('<s>')
    vocab_tags.add_word('</s>')

    # initialize random generator if not provided
    if rng is None:
        rng = np.random.RandomState()

    logger.info("Making instances...")
    # generate (context, target, tag) instances of ids
    train_set_x, train_set_y, train_set_tags = make_instances(
        train_data, vocab, vocab_tags, context_sz)
    dev_set_x, dev_set_y, dev_set_tags = make_instances(
        dev_data, vocab, vocab_tags, context_sz)
    test_set_x, test_set_y, test_set_tags = make_instances(
        test_data, vocab, vocab_tags, context_sz)

    # Binary tag -> sub-tag indicator matrix (one row per tag id).
    # very sparse matrix...better way to do it?
    num_sub_tags = vocab_tags.num_sub_tags
    feature_matrix = np.zeros((vocab_tags.size(), num_sub_tags))
    feature_matrix[0, 0] = 1  # unk encoding

    for tag, tag_id in vocab_tags:
        if tag == "<s>":
            feature_matrix[tag_id, 1] = 1
        elif tag == "</s>":
            feature_matrix[tag_id, 2] = 1
        else:
            for sub_tag in vocab_tags.map_tag_to_sub_tags[tag]:
                feature_matrix[tag_id, vocab_tags.map_sub_to_ids[sub_tag]] = 1

    # Row 1 is cleared after the fill.
    # NOTE(review): which tag id 1 stands for is not visible here -- confirm.
    feature_matrix[1, :] = 0

    # number of complete minibatches (floor division on every Python version)
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size
    n_dev_batches = dev_set_x.get_value(borrow=True).shape[0] // batch_size
    n_test_batches = test_set_x.get_value(borrow=True).shape[0] // batch_size

    # build the model
    logger.info("Build the model ...")
    index = T.lscalar()

    x = T.imatrix('x')
    y = T.ivector('y')
    t = T.ivector('t')  # the tag vector

    # create log-bilinear model
    lbl = LogBilinearLanguageModel(x, V, K, num_sub_tags, feature_matrix,
                                   context_sz, rng)

    # cost function is negative log likelihood of the training data
    cost = lbl.negative_log_likelihood(y, t)

    # compute the gradient of the cost w.r.t. every model parameter
    gparams = [T.grad(cost, param) for param in lbl.params]

    # The step size lives in a shared variable so that the per-epoch schedule
    # below actually reaches the compiled training function. A plain Python
    # float would be frozen into the graph at compile time.
    floatX = theano.config.floatX
    lr = theano.shared(np.asarray(learning_rate, dtype=floatX),
                       name='learning_rate')

    # plain SGD step; the cast keeps each update in its parameter's dtype
    updates = [(param, T.cast(param - lr * gparam, param.dtype))
               for param, gparam in zip(lbl.params, gparams)]

    def batch_givens(set_x, set_y, set_tags):
        """Map the symbolic inputs to minibatch ``index`` of one data set."""
        lo = index * batch_size
        hi = (index + 1) * batch_size
        return {x: set_x[lo:hi], y: set_y[lo:hi], t: set_tags[lo:hi]}

    # function that computes the cost on a dev minibatch
    logprob_dev = theano.function(
        inputs=[index], outputs=cost,
        givens=batch_givens(dev_set_x, dev_set_y, dev_set_tags))

    # function that computes the cost on a test minibatch
    logprob_test = theano.function(
        inputs=[index], outputs=cost,
        givens=batch_givens(test_set_x, test_set_y, test_set_tags))

    # function that returns the cost and updates the parameters
    train_model = theano.function(
        inputs=[index], outputs=cost, updates=updates,
        givens=batch_givens(train_set_x, train_set_y, train_set_tags))

    # perplexity functions
    def compute_dev_logp():
        return np.mean([logprob_dev(i) for i in xrange(n_dev_batches)])

    def compute_test_logp():
        return np.mean([logprob_test(i) for i in xrange(n_test_batches)])

    def ppl(neg_logp):
        # NOTE(review): base 2 assumes the cost is measured in bits -- confirm
        # against LogBilinearLanguageModel.negative_log_likelihood.
        return np.power(2.0, neg_logp)

    # train model
    logger.info("training model...")
    last_epoch_dev_ppl = np.inf
    best_dev_ppl = np.inf
    test_ppl = np.inf
    start_time = time.clock()
    done_looping = False
    epochs_run = 0  # epochs actually started; stays 0 when epochs == 0

    for epoch in xrange(epochs):
        epochs_run = epoch + 1
        logger.info('epoch %i' % epoch)
        for minibatch_index in xrange(n_train_batches):
            itr = epoch * n_train_batches + minibatch_index
            train_logp = train_model(minibatch_index)
            logger.info('epoch %i, minibatch %i/%i, train minibatch log prob %.4f ppl %.4f' %
                        (epoch, minibatch_index+1, n_train_batches,
                         train_logp, ppl(train_logp)))
            if (itr+1) % validation_freq == 0:
                # compute perplexity on dev set, lower is better
                dev_logp = compute_dev_logp()
                dev_ppl = ppl(dev_logp)
                logger.debug('epoch %i, minibatch %i/%i, dev log prob %.4f ppl %.4f' %
                             (epoch, minibatch_index+1, n_train_batches,
                              dev_logp, dev_ppl))
                # if we got the lowest perplexity until now
                if dev_ppl < best_dev_ppl:
                    # improve patience if loss improvement is good enough
                    if patience and dev_ppl < best_dev_ppl * improvement_thrs:
                        patience = max(patience, itr * patience_incr)
                    best_dev_ppl = dev_ppl
                    test_logp = compute_test_logp()
                    test_ppl = ppl(test_logp)
                    logger.debug('epoch %i, minibatch %i/%i, test log prob %.4f ppl %.4f' %
                                 (epoch, minibatch_index+1, n_train_batches,
                                  test_logp, test_ppl))
            # stop learning if no improvement was seen for a long time
            if patience and patience <= itr:
                done_looping = True
                break

        if done_looping:
            # Early stop: leave now so the epoch count stays accurate and no
            # learning-rate bookkeeping is done for an epoch that won't run.
            break

        # adapt learning rate for the next epoch
        if rate_update == 'simple':
            # set learning rate to 1 / (epoch+1)
            learning_rate = 1.0 / (epoch+1)
        elif rate_update == 'adaptive':
            # half learning rate if perplexity increased at end of epoch (Mnih and Teh 2012)
            this_epoch_dev_ppl = ppl(compute_dev_logp())
            if this_epoch_dev_ppl > last_epoch_dev_ppl:
                learning_rate /= 2.0
            last_epoch_dev_ppl = this_epoch_dev_ppl
        # 'constant': keep the learning rate unchanged
        lr.set_value(np.asarray(learning_rate, dtype=floatX))

    end_time = time.clock()
    total_time = end_time - start_time
    # guard against a zero-length timing interval
    epochs_per_sec = float(epochs_run) / total_time if total_time > 0 else 0.0
    logger.info('Optimization complete with best dev ppl of %.4f and test ppl %.4f' %
                (best_dev_ppl, test_ppl))
    logger.info('Training took %d epochs, with %.1f epochs/sec' % (epochs_run,
                epochs_per_sec))
    logger.info("Total training time %d days %d hours %d min %d sec." %
                (total_time/60/60/24, total_time/60/60%24, total_time/60%60, total_time%60))
    # return model
    return lbl