class TestVerifyThatPredictionsArePrecient(unittest.TestCase):
    """Compare results built from a shorter GOOG dataset with a longer one.

    The "limited" fixtures are expected to end at ``lim_end``
    (2009-12-31 16:00), while the "verify" fixtures are built over a period
    running to 2011. The tests check two things: the overlapping features
    agree, and predictions for dates after ``lim_end`` agree. Presumably
    this guards against look-ahead (future data leaking into features or
    models) -- TODO confirm intent.
    """

    def setUp(self):
        """Build a limited and an extended StockPeriod, features and learner."""
        self.start_date = dt.datetime(2004, 9, 1)
        self.end_date = dt.datetime(2010, 1, 1)
        self.sp_limited = StockPeriod('GOOG_SHORT', self.start_date, self.end_date)
        self.tf_limited = TechnicalFeatures(self.sp_limited)

        # Expected first/last index labels of the limited data (16:00 bars);
        # checked by assert_limits.
        self.lim_start = dt.datetime(2004, 9, 1, 16)
        self.lim_end = dt.datetime(2009, 12, 31, 16)

        # feature_gen ignores its argument and returns the prebuilt features,
        # presumably so the learner reuses the same TechnicalFeatures object.
        self.mlearn_lim = MachineLearner(
            'GOOG_SHORT',
            feature_gen=lambda sp: self.tf_limited,
            stock_period=self.sp_limited)

        self.sp = StockPeriod('GOOG', self.start_date, dt.datetime(2011, 1, 1))
        self.tf = TechnicalFeatures(self.sp)

        # NOTE(review): the label 'GOOG_' differs from the 'GOOG' symbol used
        # for self.sp; confirm whether this is intentional.
        self.mlearn_verify = MachineLearner(
            'GOOG_',
            feature_gen=lambda sp: self.tf,
            stock_period=self.sp)

    def assert_limits(self, sp, tf, start_d, end_d):
        """Assert that close data and relative close features span [start_d, end_d].

        Checks the first and last index labels of ``sp.close_data`` and of
        ``tf.relative_data['close']``.
        """
        self.assertEqual(sp.close_data.index[0], start_d)
        self.assertEqual(sp.close_data.index[-1], end_d)
        self.assertEqual(tf.relative_data['close'].index[0], start_d)
        self.assertEqual(tf.relative_data['close'].index[-1], end_d)

    def test_correct_boundaries(self):
        """The limited fixtures, and the learner built on them, stop at lim_end."""
        self.assert_limits(self.sp_limited, self.tf_limited, self.lim_start, self.lim_end)
        self.assert_limits(self.mlearn_lim.sp, self.mlearn_lim.feats, self.lim_start, self.lim_end)

    def test_predictions_are_the_same(self):
        """Row-by-row predictions of the limited learner match predict_period.

        ``mlearn_lim`` (limited data) predicts rows dated after ``lim_end``
        one at a time. ``mlearn_verify`` (extended data), trained over the
        same window, predicts the same dates in one call. Index and values
        must match.
        """
        learn_start = dt.datetime(2006, 1, 1)
        learn_end = dt.datetime(2008, 1, 1)

        # Label slice from lim_end, then drop its first row (the lim_end row
        # itself, when present) and keep at most 99 later rows.
        features = self.tf.get_features()[self.lim_end:][1:100]
        self.assertGreater(features.index[0], self.lim_end)

        self.mlearn_lim.learn_period(learn_start, learn_end)
        # Learning must leave the limited learner's data bounds unchanged.
        self.assert_limits(self.mlearn_lim.sp, self.mlearn_lim.feats, self.lim_start, self.lim_end)

        # predict() is given a one-row batch; keep its single result.
        predictions = [self.mlearn_lim.predict([feat_row.values])[0]
                       for _, feat_row in features.iterrows()]
        df_test_result = pd.DataFrame(index=features.index, data=predictions)
        print(df_test_result)

        self.mlearn_verify.learn_period(learn_start, learn_end)
        df_verify_result = self.mlearn_verify.predict_period(features.index[0], features.index[-1])
        print(df_verify_result)

        # np.equal broadcasts, so compare lengths first and flatten the
        # values. Otherwise a length-1 result or an (n, 1)-vs-(n,) pair
        # would compare the wrong elements.
        self.assertEqual(len(df_test_result.index), len(df_verify_result.index))
        self.assertTrue(np.equal(df_test_result.index, df_verify_result.index).all())
        # NOTE(review): assumes predict_period returns one prediction per
        # row (a Series or a single-column frame) -- confirm.
        test_values = np.ravel(df_test_result.values)
        verify_values = np.ravel(df_verify_result.values)
        self.assertEqual(len(test_values), len(verify_values))
        self.assertTrue(np.equal(test_values, verify_values).all())

    def test_features_equal(self):
        """The last 50 limited feature rows equal the full features on those dates."""
        limited_feats = self.tf_limited.get_features()
        feats = self.tf.get_features()

        sub_set_limited = limited_feats[-50:]
        self.assertEqual(sub_set_limited.index[-1], self.lim_end)
        # Label-based slice: both endpoints are included.
        sub_set_full = feats[sub_set_limited.index[0]:sub_set_limited.index[-1]]
        self.assertGreater(feats.index[-1], self.lim_end)
        # Check shapes first. A row or column mismatch must fail the test,
        # not raise a ValueError or broadcast inside np.equal.
        self.assertEqual(sub_set_limited.values.shape, sub_set_full.values.shape)
        self.assertTrue(np.equal(sub_set_limited.values, sub_set_full.values).all())