def __init__(self):
    """Set up metadata / citation-network handles and the empty feature
    tables, then populate the tables via load_features()."""
    self.aan = aanmeta()
    self.cnw = CitationNetwork()

    # One dict per feature family; load_features() fills them in.
    (self.prestige_features,
     self.position_features,
     self.content_features,
     self.style_features) = {}, {}, {}, {}

    self.load_features()
    def main(self):
        """Build the train / model-test / forecast-test paper splits and
        write the feature, response, and time-step files for one experiment.

        Command line:
            argv[1] -- reference year n (papers up to n are usable for training)
            argv[2] -- optional forecast offset in years (default 3);
                       the test set is the papers published in year n + offset
        """
        n = int(sys.argv[1])
        # BUG FIX: argv entries are strings; the original left `diff` as a
        # string when supplied, so `n + diff` below raised TypeError.
        diff = int(sys.argv[2]) if len(sys.argv) > 2 else 3

        init = 1980
        last = 2006

        aan = aanmeta()
        all_papers = aan.get_restricted_papers(init, last)

        till_n = [p for p in all_papers if p.year <= n]

        # 80/20 random split of the papers up to year n; the held-out 20% is
        # the model-test set.  The forecast test set is year n + diff exactly.
        training = random.sample(till_n, int(math.ceil(0.8 * len(till_n))))
        model_test = [i for i in till_n if i not in training]
        test = [i for i in all_papers if i.year == (n + diff)]

        print("Total files: %d" % (len(training) + len(model_test) + len(test)))

        # Load the pre-pruned n-gram features: one "<pid>\t<f1><>f2>..." line
        # per paper.  `with` closes the handle (the original leaked it).
        self.feats = {}
        with open("1980_2006.pruned_feats", "r") as featfile:
            for line in featfile:
                pid, featstr = line.strip().split("\t")
                self.feats[pid] = featstr.split("<>")

        training_fname = "experiment_files/1980_%s.train.txt" % n
        model_test_fname = "experiment_files/1980_%s.modeltest.txt" % n
        test_fname = "experiment_files/%s.test.txt" % (n + diff)

        self.write_data(training, training_fname)
        self.write_data(model_test, model_test_fname)
        self.write_data(test, test_fname)

        # Creating the response files; `with` guarantees each one is flushed
        # and closed (the original left all three handles open).
        self.cnw = CitationNetwork()
        with open("experiment_files/1980_%s.train.resp.txt" % n, "w") as resp:
            self.write_response(training, resp, n)
        with open("experiment_files/1980_%s.modeltest.resp.txt" % n, "w") as resp:
            self.write_response(model_test, resp, n)
        with open("experiment_files/%s.test.resp.txt" % (n + diff), "w") as resp:
            self.write_response(test, resp, n)

        # Write the time-step file: "<pid>\t<year>" for every paper in [init, n].
        with open("experiment_files/%s_%s_timesteps.txt" % (init, n), "w") as ts_file:
            for pid in [i.pid for i in all_papers if init <= i.year <= n]:
                ts_file.write("%s\t%d\n" % (pid, get_year_from_id(pid)))
from nltk.corpus import stopwords
from utils import get_ft_ngrams

import time

from aanmeta import aanmeta
from citation_nw import CitationNetwork

# Given a reference year and forecast year
# start from 1980 and get all titles till reference
# Extract 1,2 and 3 grams from the titles and create feature files for training data

# Year window over which papers are collected.
lyear = 1980
hyear = 2006

aan = aanmeta()
refpapers = aan.get_restricted_papers(1980, 2006)

# # refpapers.sort(key=lambda x: x.year)

# Separator used when joining n-gram tokens into a single feature string.
ngram_sep = "_"

# Feature file for the full 1980-2006 window.  NOTE(review): opened here and
# presumably written further down the script; it is never closed in the
# portion visible here -- confirm the tail of the script closes it.
feat_file_name = "%s_%s.pruned_feats" % (lyear, hyear)
feat_file = open(feat_file_name, 'w')
# resp_file = open(resp_file_name, 'w')

# NOTE(review): this rebinding shadows the imported `stopwords` module with
# the English stop-word list.  It works because .words() is called before the
# shadowing takes effect, and later code presumably relies on the rebound
# name, so it is left as-is -- a distinct name (e.g. `stop_words`) would be safer.
stopwords = stopwords.words('english')
featfreq = {}   # presumably n-gram -> occurrence count; filled beyond this view
initfeats = {}  # presumably per-paper feature lists; filled beyond this view

print "Total %d papers" % len(refpapers)