예제 #1
0
 def setUpClass(cls) -> None:
     """Initialise qlib1 once for the test class.

     The provider URI is the resolved ``QLIB_DIR`` path; ``None`` is passed
     for both the expression cache and the dataset cache.
     """
     qlib1.init(
         provider_uri=str(QLIB_DIR.resolve()),
         expression_cache=None,
         dataset_cache=None,
     )
예제 #2
0
    def setUpClass(cls) -> None:
        """Ensure the simple CN data set is available, then initialise qlib1.

        When ``exists_qlib_data`` reports nothing at the default location, the
        ``scripts`` directory three levels above this file is put on
        ``sys.path`` so ``get_data.GetData`` can be imported and asked to
        populate the target directory.
        """
        # use default data
        provider_uri = "~/.qlib1/qlib_data/cn_data_simple"  # target_dir
        data_missing = not exists_qlib_data(provider_uri)
        if data_missing:
            print(f"Qlib data is not found in {provider_uri}")
            scripts_dir = Path(__file__).resolve().parent.parent.parent / "scripts"
            sys.path.append(str(scripts_dir))
            from get_data import GetData

            fetcher = GetData()
            fetcher.qlib_data(name="qlib_data_simple", target_dir=provider_uri)
        qlib1.init(provider_uri=provider_uri, region=REG_CN)
예제 #3
0
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """Run the task described by a YAML workflow configuration file.

    Parameters
    ----------
    config_path
        Path of the YAML configuration file to load.
    experiment_name
        Name forwarded to ``task_train``, by default "workflow".
    uri_folder
        Folder under the current working directory used to build the
        experiment manager's ``file:`` URI, by default "mlruns".
    """
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
    # configuration file must come from a trusted source.
    with open(config_path) as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # Apply the `sys` section of the configuration.
    sys_config(config, config_path)

    # Point the shared experiment-manager settings at <cwd>/<uri_folder>.
    exp_manager = C["exp_manager"]
    tracking_dir = Path(os.getcwd()).resolve() / uri_folder
    exp_manager["kwargs"]["uri"] = "file:" + str(tracking_dir)

    qlib_init_kwargs = config.get("qlib_init")
    qlib1.init(**qlib_init_kwargs, exp_manager=exp_manager)

    task_config = config.get("task")
    task_train(task_config, experiment_name=experiment_name)
예제 #4
0
# Experiment-manager settings handed to qlib1.init below: an MLflowExpManager
# whose URI is a local "file:" URI built from exp_path (defined outside this
# snippet).
exp_manager = {
    "class": "MLflowExpManager",
    "module_path": "qlib1.workflow.expm",
    "kwargs": {
        "uri": "file:" + exp_path,
        "default_exp_name": "Experiment",
    },
}
# When no data is found at provider_uri, populate it via GetData before
# initialising (presumably a download -- confirm against get_data).
if not exists_qlib_data(provider_uri):
    print(f"Qlib data is not found in {provider_uri}")
    # Make the `scripts` directory located one level above this file's
    # directory importable so that `get_data` can be resolved.
    sys.path.append(
        str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
    from get_data import GetData

    GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib1.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)


# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
    """Reject keyword arguments that the decorated callable does not declare.

    Parameters
    ----------
    function_to_decorate
        Callable whose positional-or-keyword and keyword-only parameter names
        define the set of accepted keyword arguments (``self`` is excluded).

    Returns
    -------
    callable
        A wrapper that validates the keyword names on every call and then
        forwards all arguments to ``function_to_decorate``.

    Raises
    ------
    ValueError
        Raised by the wrapper when a keyword argument is not one of the
        declared parameter names.
    """

    @functools.wraps(function_to_decorate)
    def _return_wrapped(*args, **kwargs):
        """Internal wrapper function."""
        argspec = inspect.getfullargspec(function_to_decorate)
        valid_names = set(argspec.args + argspec.kwonlyargs)
        valid_names.discard("self")
        for arg_name in kwargs:
            if arg_name not in valid_names:
                # Sort the names so the message is stable across runs.
                raise ValueError("Unknown argument seen '%s', expected: [%s]" %
                                 (arg_name, ", ".join(sorted(valid_names))))
        # Validation passed: delegate to the wrapped callable.
        return function_to_decorate(*args, **kwargs)

    # Without this return the decorated name would be bound to None.
    return _return_wrapped
예제 #5
0
    def __init__(
        self,
        qlib_dir: str,
        csv_path: str,
        check_fields: str = None,
        freq: str = "day",
        symbol_field_name: str = "symbol",
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        max_workers: int = 16,
    ):
        """Collect the symbols, csv files and fields to compare, and initialise qlib1.

        Parameters
        ----------
        qlib_dir : str
            qlib1 dir; must contain a ``features`` directory with one entry per symbol
        csv_path : str
            origin csv path; either a directory of ``*<file_suffix>`` files or a single file
        check_fields : str, optional
            comma-separated field names (a non-string iterable of names is used as-is);
            by default None, in which case the names are taken from the ``*.bin`` files
            in the first entry of ``qlib_dir/features`` (the text before the first ".")
        freq : str, optional
            freq, value from ["day", "1m"]
        symbol_field_name: str, optional
            symbol field name, by default "symbol"
        date_field_name: str, optional
            date field name, by default "date"
        file_suffix: str, optional
            csv file suffix, by default ".csv"
        max_workers: int, optional
            max workers, by default 16
        """
        self.qlib_dir = Path(qlib_dir).expanduser()
        bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
        self.qlib_symbols = sorted(path.name.lower() for path in bin_path_list)

        resolved_dir = str(self.qlib_dir.resolve())
        qlib1.init(
            provider_uri=resolved_dir,
            mount_path=resolved_dir,
            auto_mount=False,
            redis_port=-1,
        )

        csv_path = Path(csv_path).expanduser()
        if csv_path.is_dir():
            self.csv_files = sorted(csv_path.glob(f"*{file_suffix}"))
        else:
            self.csv_files = [csv_path]

        if check_fields is None:
            # glob() yields Path objects, which have no ``split``; the field
            # name has to be derived from the file name string instead.
            check_fields = [
                bin_file.name.split(".")[0]
                for bin_file in bin_path_list[0].glob("*.bin")
            ]
        elif isinstance(check_fields, str):
            check_fields = check_fields.split(",")
        self.check_fields = [field.strip() for field in check_fields]
        # Each field name is prefixed with "$" for use as a qlib field.
        self.qlib_fields = [f"${field}" for field in self.check_fields]
        self.max_workers = max_workers
        self.symbol_field_name = symbol_field_name
        self.date_field_name = date_field_name
        self.freq = freq
        self.file_suffix = file_suffix
예제 #6
0
from qlib1.workflow import R
from qlib1.workflow.record_temp import SignalRecord, PortAnaRecord




# Script entry point: make sure the CN data set exists, initialise qlib1, then
# define the market/benchmark and the data-handler time ranges used below.
if __name__ == "__main__":

    # use default data
    provider_uri = "~/.qlib1/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        # Make the `scripts` directory located one level above this file's
        # directory importable.
        sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
        # NOTE(review): no `from get_data import GetData` is visible in this
        # snippet; confirm GetData is imported elsewhere, otherwise this line
        # raises NameError.
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

    qlib1.init(provider_uri=provider_uri, region=REG_CN)

    market = "csi300"
    benchmark = "SH000300"

    ###################################
    # train model
    ###################################
    # Overall time range, fitting time range and instrument universe for the
    # data handler.
    data_handler_config = {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
        "fit_end_time": "2014-12-31",
        "instruments": market,
    }