def test_load_model(do_load_model, filename, local_dir, extension, remote_dir,
                    copy_from_remote_to_local, always_fetch_remote):
    """Load a model and check the local and remote mock states agree on it.

    Every argument is forwarded, positionally and unchanged, to
    ``load_model``. The arguments are presumably supplied by pytest
    fixtures or parametrization defined elsewhere in this file -- TODO
    confirm.

    After the call, the ``"model"`` entry of the module-level
    ``mock_local_load_model_state`` must compare equal to the ``"model"``
    entry of ``mock_remote_load_model_state``.
    """
    # The two mock-state objects are only read (subscripted) below, never
    # rebound, so no ``global`` declaration is required to reach them.
    forwarded_args = (
        do_load_model,
        filename,
        local_dir,
        extension,
        remote_dir,
        copy_from_remote_to_local,
        always_fetch_remote,
    )
    load_model(*forwarded_args)

    assert (
        mock_local_load_model_state["model"]
        == mock_remote_load_model_state["model"]
    )
print_data_bunch(data_bunch_source)

# ### Load Model

# ##### Get custom loss function
# custom_objects stays None unless the configured loss is "get_custom_loss".
# In that case the loss is built from its configured params and handed to
# tasks.load_model keyed by the loss function's __name__.
if config["init"]["get_loss_function"]["name"] != "get_custom_loss":
    custom_objects = None
else:
    # NOTE(review): the selector comes from config["init"] while the params
    # come from config["exec"] -- confirm this split is intended.
    loss = get_loss_function(**config["exec"]["get_loss_function"]["params"])
    custom_objects = {loss.__name__: loss}
print(custom_objects)

# NOTE(review): unlike the other config["exec"] sections used here, the
# load_model entry is unpacked directly rather than via ["params"] --
# confirm this matches the config's shape.
model = tasks.load_model(
    load_model=load_model,
    copy_from_remote_to_local=copy_from_remote_to_local,
    custom_objects=custom_objects,
    **config["exec"]["load_model"]
)

# ### Save (formatted) config
tasks.store_artifacts(
    store_artifact_locally,
    copy_from_local_to_remote,
    config,
    **config["exec"]["save_formatted_config"]["params"]
)
print("Config stored with following parameters")
print_dict(config["exec"]["save_formatted_config"]["params"])

# ### Save Session

# ##### Save session info