# Exemplo n.º 1 (scraped sample separator; commented out so the file parses)
# 0
def main():
    """Fetch the base config, apply CLI overrides, and launch training.

    Downloads the base configuration from the URL in the
    ``BASE_CONFIG_PATH`` environment variable, writes it to
    ``config.json``, merges in wandb-style overrides parsed from
    ``sys.argv``, and starts a PyTorch training run.
    """
    # Translate command-line arguments into wandb-style overrides.
    wandb_overrides = convert_args(sys.argv)
    # Fetch the base configuration, then persist it locally.
    response = requests.get(os.environ["BASE_CONFIG_PATH"])
    with open("config.json", "w+") as config_file:
        config_file.write(response.text)
    merged_config = make_config(wandb_overrides, "config.json")
    print(merged_config)
    train_function("PyTorch", merged_config)
# Exemplo n.º 2 (scraped sample separator; commented out so the file parses)
# 0
def loop_through(data_dir: str,
                 interrmittent_gcs: bool = False,
                 use_transfer: bool = True,
                 start_index: int = 0,
                 end_index: int = 25) -> None:
    """Build and run a training config for each flow CSV in ``data_dir``.

    Processes the sorted directory listing over the half-open range
    ``[start_index, end_index)``, writing one JSON config per file and
    invoking ``train_function`` on it.  A failure on one file is logged
    and skipped so the remaining files still train.

    Parameters
    ----------
    data_dir : str
        Directory containing ``*_flow.csv`` files.
    interrmittent_gcs : bool
        Unused; kept (misspelling and all) for caller compatibility.
    use_transfer : bool
        If True, seed each run with the most recently created ``.pth``
        checkpoint found under ``model_save``.
    start_index, end_index : int
        Half-open slice of the sorted file list to process.
    """
    os.makedirs("model_save", exist_ok=True)
    sorted_dir_list = sorted(os.listdir(data_dir))
    # Clamp so a directory with fewer than end_index files cannot
    # raise IndexError (the original indexed past the end).
    stop = min(end_index, len(sorted_dir_list))
    for file_name in sorted_dir_list[start_index:stop]:
        # "<gage><station>_flow.csv" -> (gage_id, station_id) via the
        # project helper that splits on the first letter.
        station_id_gage = file_name.split("_flow.csv")[0]
        res = split_on_letter(station_id_gage)
        gage_id, station_id = res[0], res[1]
        file_path_name = os.path.join(data_dir, file_name)
        print("Training on: " + file_path_name)
        correct_file = None
        if use_transfer and len(os.listdir("model_save")) > 1:
            # Transfer-learn from the newest checkpoint on disk.
            paths = [os.path.join("model_save", f)
                     for f in os.listdir("model_save")
                     if f.endswith(".pth")]
            correct_file = max(paths, key=os.path.getctime)
            print(correct_file)
        config = make_config_file(file_path_name, gage_id, station_id,
                                  correct_file)
        file_name_json = station_id + "config_f" + ".json"
        with open(file_name_json, "w+") as f:
            json.dump(config, f)
        try:
            train_function("PyTorch", config)
        except Exception as e:
            # Log and continue so one bad file does not abort the batch.
            # (Typo "occured" fixed in the message.)
            print("An exception occurred for: " + file_name_json)
            traceback.print_exc()
            print(e)
 def test_decoder_multi_step(self):
     """Multi-step decode must not leak target values into the output."""
     # A save_path left over from another test would redirect model
     # loading; drop it if present (pop is a no-op otherwise).
     self.model_params.pop("save_path", None)
     forecast_model = train_function("PyTorch", self.model_params)
     targets = torch.Tensor([3, 4, 5]).repeat(1, 336, 1)
     decoded = simple_decode(forecast_model.model,
                             torch.ones(1, 5, 3),
                             336,
                             targets,
                             output_len=3)
     # The sentinel value 3 from the targets must not appear in the
     # first output channel — that would indicate target leakage.
     self.assertFalse(3 in decoded[:, :, 0])
 def test_multivariate_single_step(self):
     """Multivariate decode must not leak either target sentinel value."""
     # Bug fix: the original checked model_params3 but deleted from
     # model_params — which could raise KeyError and always left
     # save_path inside the dict actually passed to train_function.
     if "save_path" in self.model_params3:
         del self.model_params3["save_path"]
     t = torch.Tensor([3, 6, 5]).repeat(1, 100, 1)
     forecast_model3 = train_function("PyTorch", self.model_params3)
     output = simple_decode(forecast_model3.model,
                            torch.ones(1, 5, 3),
                            100,
                            t,
                            output_len=3,
                            multi_targets=2)
     # Neither sentinel (3 or 6) may leak from the targets into the
     # decoded output.
     self.assertFalse(3 in output)
     self.assertFalse(6 in output)
# Exemplo n.º 5 (scraped sample separator; commented out so the file parses)
# 0
    # Stage per-region training CSVs in a temp dir beside the raw data.
    # (Body of a function whose signature is above this chunk; data_dir,
    # df, italian_regions and logger are bound by the enclosing scope.)
    temp_training_data_dir = data_dir / "temp"
    temp_training_data_dir.mkdir(exist_ok=True)

    # pandas DataFrame.query template; filled per region below.
    query = "region_identifier=='{}'"
    for region in italian_regions:
        # Write this region's slice of df to disk and get its length.
        region_df, dataset_length, file_path = format_corona_data(
            df.query(query.format(region)), region, temp_training_data_dir)
        # One wandb sweep per region under the covid-forecast project.
        sweep_id = wandb.sweep(generate_wandb_sweep_config(
            f"Multivariate TS sweep new cases, mobility, weather -- {region}"),
                               project="covid-forecast")
        # NOTE(review): this lambda closes over file_path/dataset_length
        # late; it is only correct because wandb.agent runs it within this
        # loop iteration — confirm agent does not defer execution.
        wandb.agent(
            sweep_id, lambda: train_function(
                "PyTorch",
                generate_training_config(
                    str(file_path),
                    feature_columns=[
                        'retail_recreation', 'grocery_pharmacy', 'parks',
                        'transit_stations', 'workplaces', 'residential',
                        'avg_temperature', 'min_temperature',
                        'max_temperature', 'relative_humidity',
                        'specific_humidity', 'pressure', "new_cases"
                    ],
                    target_column=["new_cases"],
                    df_len=dataset_length)))

    logger.info("done")
    # Remove staging data and run artifacts once all regions finish.
    shutil.rmtree(temp_training_data_dir)
    shutil.rmtree("wandb")
    shutil.rmtree("model_save")