def train_with_sigana(uri_path: str = None):
    """Train the CSI300 GBDT model, then run SignalRecord and SigAnaRecord.

    Parameters
    ----------
    uri_path: str, optional
        Forwarded as ``uri`` to ``R.start``. The default ``None`` is passed
        through unchanged.

    Returns
    -------
    pred_score: pandas.DataFrame
        predict scores, loaded from SignalRecord's ``"pred.pkl"`` artifact
    performance: dict
        model performance as ``{"ic": ..., "ric": ...}``, loaded from
        SigAnaRecord's ``"ic.pkl"`` and ``"ric.pkl"`` artifacts
    uri_path: str
        value of ``R.get_uri()`` read while the experiment is still active
    """
    model = init_instance_by_config(CSI300_GBDT_TASK["model"])
    dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])

    # start exp
    with R.start(experiment_name="workflow_with_sigana", uri=uri_path):
        R.log_params(**flatten_dict(CSI300_GBDT_TASK))
        model.fit(dataset)
        recorder = R.get_recorder()

        # predict: SignalRecord generates and stores the prediction scores
        sr = SignalRecord(model, dataset, recorder)
        sr.generate()
        pred_score = sr.load("pred.pkl")

        # calculate ic and ric from the recorded predictions
        sar = SigAnaRecord(recorder)
        sar.generate()
        ic = sar.load("ic.pkl")
        ric = sar.load("ric.pkl")

        # read the uri before leaving the experiment context
        uri_path = R.get_uri()
    return pred_score, {"ic": ic, "ric": ric}, uri_path
def train_with_sigana(uri_path: str = None):
    """Train the GBDT model from ``task``, then run SigAnaRecord and SignalMseRecord.

    NOTE(review): this is a second ``def train_with_sigana`` in this file. Because
    it comes later, it is the definition bound at import time. ``uri_path`` is
    accepted (default ``None``) so that calls written against the other
    definition's ``train_with_sigana(uri_path)`` signature do not raise
    ``TypeError``.

    Parameters
    ----------
    uri_path: str, optional
        Forwarded as ``uri`` to ``R.start``.

    Returns
    -------
    pred_score: pandas.DataFrame
        predict scores, loaded from the ``"pred.pkl"`` artifact
    performance: dict
        model performance as ``{"ic": ..., "ric": ...}``
    uri_path: str
        value of ``R.get_uri()`` read while the experiment is still active
    """
    # NOTE(review): ``task`` is not defined in the visible code, while the other
    # workflows in this file use ``CSI300_GBDT_TASK`` -- confirm which is intended.
    model = init_instance_by_config(task["model"])
    dataset = init_instance_by_config(task["dataset"])

    # start exp
    with R.start(experiment_name="workflow_with_sigana", uri=uri_path):
        R.log_params(**flatten_dict(task))
        model.fit(dataset)

        # predict and calculate ic and ric
        recorder = R.get_recorder()
        sar = SigAnaRecord(recorder, model=model, dataset=dataset)
        sar.generate()
        # ic/ric are looked up through SigAnaRecord's own artifact path;
        # pred.pkl is loaded by its bare name
        ic = sar.load(sar.get_path("ic.pkl"))
        ric = sar.load(sar.get_path("ric.pkl"))
        pred_score = sar.load("pred.pkl")

        smr = SignalMseRecord(recorder)
        smr.generate()

        # read the uri before leaving the experiment context
        uri_path = R.get_uri()
    return pred_score, {"ic": ic, "ric": ric}, uri_path
def fake_experiment():
    """Run a minimal experiment to check that a custom uri is scoped to ``R.start``.

    No model is trained. The experiment only logs the flattened
    ``CSI300_GBDT_TASK`` config.

    Returns
    -------
    pass_or_not_for_default_uri: bool
        whether ``R.get_uri()`` after the ``with R.start(...)`` block equals
        its value from before the block
    pass_or_not_for_current_uri: bool
        whether ``R.get_uri()`` inside the block equals the uri handed to
        ``R.start``
    temporary_exp_dir: str
        the uri string used for the experiment
    """
    uri_before = R.get_uri()
    temp_uri = "file:./temp-test-exp-mag"

    # start exp with an explicit uri; it should only apply inside the block
    with R.start(experiment_name="fake_workflow_for_expm", uri=temp_uri):
        R.log_params(**flatten_dict(CSI300_GBDT_TASK))
        uri_inside = R.get_uri()

    uri_after = R.get_uri()
    default_restored = uri_before == uri_after
    custom_applied = temp_uri == uri_inside
    return default_restored, custom_applied, temp_uri
def train_mse():
    """Train the CSI300 GBDT model and generate a ``SignalMseRecord`` for it.

    Returns
    -------
    uri
        value of ``R.get_uri()`` read inside the ``"workflow"`` experiment
    """
    cfg = CSI300_GBDT_TASK
    model = init_instance_by_config(cfg["model"])
    dataset = init_instance_by_config(cfg["dataset"])

    with R.start(experiment_name="workflow"):
        R.log_params(**flatten_dict(cfg))
        model.fit(dataset)
        mse_record = SignalMseRecord(R.get_recorder(), model=model, dataset=dataset)
        mse_record.generate()
        experiment_uri = R.get_uri()
    return experiment_uri
def train_multiseg():
    """Train the CSI300 GBDT model and run a ``MultiSegRecord`` on valid/test.

    Returns
    -------
    uri
        value of ``R.get_uri()`` read inside the ``"workflow"`` experiment
    """
    cfg = CSI300_GBDT_TASK
    model = init_instance_by_config(cfg["model"])
    dataset = init_instance_by_config(cfg["dataset"])
    segments = {"valid": "valid", "test": "test"}

    with R.start(experiment_name="workflow"):
        R.log_params(**flatten_dict(cfg))
        model.fit(dataset)
        seg_record = MultiSegRecord(model, dataset, R.get_recorder())
        # the second positional argument is passed as True, unchanged from the
        # original call (presumably the ``save`` flag -- confirm against MultiSegRecord)
        seg_record.generate(segments, True)
        experiment_uri = R.get_uri()
    return experiment_uri