class TestFolderSource(TestCase):
    """Tests for a ``ReportLocalData`` source rooted at ``DEFAULT_FOLDER``."""

    def setUp(self):
        # Fixtures shared by every test: the stock under test, the source
        # object, and the gzip file path test_load_data writes to.
        self.stock_id = TEST_STOCK_ID
        self.source = ReportLocalData(DEFAULT_FOLDER)
        # NOTE(review): the file name hard-codes 600096; presumably this
        # matches TEST_STOCK_ID — confirm.
        self.expected_path = DEFAULT_FOLDER + '/600096.gz'

    def tearDown(self):
        # Remove whatever the test left on disk; skip if already gone.
        if not os.path.exists(DEFAULT_FOLDER):
            return
        shutil.rmtree(DEFAULT_FOLDER)

    def test_initial_file_source(self):
        """The source reports the folder it was built with and creates it."""
        actual_folder = self.source.source_folder

        self.assertEqual(DEFAULT_FOLDER, actual_folder)
        # Constructing ReportLocalData should have created the folder.
        self.assertTrue(os.path.exists(actual_folder))

    def test_load_data(self):
        """A frame saved as gzip CSV at the expected path is read back intact."""
        sample = read_sina_600096_test_data()
        sample.to_csv(self.expected_path, compression="gzip")

        loaded = self.source.load_data(self.stock_id)
        assert_frame_equal(sample, loaded)

    def test_load_data_empty(self):
        """With no file written for the stock, load_data returns an empty result."""
        loaded = self.source.load_data(self.stock_id)

        self.assertTrue(loaded.empty)
Example #2
0
def create_china_stock_income_repository(local_source=None,
                                         base_folder="reportData"
                                         ) -> ChinaIncomeRepository:
    """Build a ``ChinaIncomeRepository`` around a local data source.

    Args:
        local_source: storage object to wrap. When ``None`` (the default), a
            ``ReportLocalData`` rooted at
            ``<base_folder>/greenseer/china_income_reports`` is created.
        base_folder: root folder; only consulted when *local_source* is
            ``None``.

    Returns:
        A ``ChinaIncomeRepository`` wrapping the chosen source.
    """
    source = local_source
    if source is None:
        income_folder = base_folder + "/greenseer/china_income_reports"
        source = ReportLocalData(income_folder)
    return ChinaIncomeRepository(source)
Example #3
0
def set_local_path(local_path: str):
    """Point the module's report storage at *local_path*.

    First refreshes the shared ``_repository`` with the new path, then
    rebinds the module-level ``_local_all_reports_repo`` to a
    ``ReportLocalData`` rooted at ``<local_path>/chinaReports``.
    """
    global _local_all_reports_repo
    _repository.refresh(local_path)
    reports_folder = local_path + "/chinaReports"
    _local_all_reports_repo = ReportLocalData(reports_folder)
Example #4
0
        return self._cash

    @property
    def stock_info(self) -> pd.DataFrame:
        """Stock-info table.

        Calls ``self._stock_info.load_data()`` on each access and returns
        its result.
        """
        loaded = self._stock_info.load_data()
        return loaded

    @property
    def local_path(self) -> str:
        """Local storage path held by this object (read-only view of ``_local_path``)."""
        path = self._local_path
        return path


# Shared repository; used as the default by stock_info() and refreshed by
# set_local_path().
_repository = ChinaReportRepository()

# Module-named logger instead of the root logger, so this module's output can
# be filtered/tuned independently; records still propagate to root handlers.
_logger = logging.getLogger(__name__)

# Local store for the combined reports; rebound by set_local_path().
_local_all_reports_repo = ReportLocalData(DEFAULT_LOCAL_PATH + "/chinaReports")

# NOTE(review): usage not visible here — presumably the dataset name under
# which all finance reports are stored; confirm.
_ALL_REPORTS_NAME = "all_finance_reports"


def stock_info(repository=_repository) -> pd.DataFrame:
    """Return the ``stock_info`` attribute of *repository*.

    The default is the module-level ``_repository`` as it was bound when
    this function was defined.
    """
    info = repository.stock_info
    return info


def set_local_path(local_path: str):
    """Point the module's report storage at *local_path*.

    First refreshes the shared ``_repository`` with the new path, then
    rebinds the module-level ``_local_all_reports_repo`` to a
    ``ReportLocalData`` rooted at ``<local_path>/chinaReports``.
    """
    global _local_all_reports_repo
    _repository.refresh(local_path)
    reports_folder = local_path + "/chinaReports"
    _local_all_reports_repo = ReportLocalData(reports_folder)


def load_train_data(train_size=10,
 def setUp(self):
     self.stock_id = TEST_STOCK_ID
     self.source = ReportLocalData(DEFAULT_FOLDER)
     self.expected_path = DEFAULT_FOLDER + '/600096.gz'