def test_add_extensions_metadata_2(self):
    """Test that add_extensions_metadata doesn't add extensions that are not used.

    In this case we use a config containing torch, and we call
    ``make_component`` on torch so that the config can be compiled.
    After that, ``add_extensions_metadata`` is called with torch, which is
    a valid extension for the config (redundant, but valid), and then with
    torch merged into ``TestSerializationExtensions.EXTENSIONS``. In both
    cases only the torch entry is expected in ``schema._extensions``.

    """
    TORCH_TAG_PREFIX = "torch"
    make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')

    # The tag and its keyword arguments have to sit on separate lines:
    # folded onto a single line this is no longer a tagged YAML mapping.
    config = """
    !torch.Linear
      in_features: 2
      out_features: 2
    """

    schema = yaml.load(config)

    # A redundant (but valid) extension is recorded as given.
    schema.add_extensions_metadata({"torch": "torch"})
    assert schema._extensions == {"torch": "torch"}

    # Extensions the config does not use must not be added.
    mixed_ext = TestSerializationExtensions.EXTENSIONS.copy()
    mixed_ext.update({"torch": "torch"})
    schema.add_extensions_metadata(mixed_ext)
    assert schema._extensions == {"torch": "torch"}
def test_save_to_file_and_load_from_file_roundtrip_complex_nontorch_root(
        self, complex_multi_layered_nontorch_root, pickle_only,
        compress_save_file):
    """Save a state to a file, load it back and restore it into a new object.

    The state of one object built by the ``complex_multi_layered_nontorch_root``
    fixture is written with ``save_state_to_file`` using the given
    ``pickle_only`` / ``compress_save_file`` flags, read back with
    ``load_state_from_file`` and loaded into a second object. States and
    their ``_metadata`` are compared with ``check_mapping_equivalence``.

    """
    make_component(torch.nn.Module, "torch", only_module='torch.nn')

    source_obj = complex_multi_layered_nontorch_root(from_config=True)
    state = source_obj.get_state()

    # The file is looked up under the base path plus these suffixes,
    # depending on the flags passed to save_state_to_file.
    suffix = ('.pkl' if pickle_only else '') + \
        ('.tar.gz' if compress_save_file else '')

    with tempfile.TemporaryDirectory() as tmp_dir:
        base_path = os.path.join(tmp_dir, 'savefile.flambe')
        save_state_to_file(state, base_path, compress_save_file, pickle_only)
        list_files(base_path)

        state_loaded = load_state_from_file(base_path + suffix)
        check_mapping_equivalence(state, state_loaded)
        check_mapping_equivalence(state._metadata, state_loaded._metadata)

        target_obj = complex_multi_layered_nontorch_root(from_config=True)
        # State of the fresh object, captured before anything is loaded.
        initial_state = target_obj.get_state()
        target_obj.load_state(state_loaded, strict=False)

        source_state = source_obj.get_state()
        target_state = target_obj.get_state()
        check_mapping_equivalence(target_state, source_state)
        check_mapping_equivalence(source_state._metadata,
                                  target_state._metadata)
        check_mapping_equivalence(initial_state._metadata,
                                  state_loaded._metadata)
def test_save_to_file_and_load_from_file_roundtrip_complex(
        self, complex_multi_layered):
    """Save a modified state to disk, load it and restore it into a new object.

    An attribute on a nested child of the ``complex_multi_layered`` object
    is changed before saving, so the comparison after ``load_state`` only
    holds if the current values were written out.

    """
    make_component(torch.nn.Module, "torch", only_module='torch.nn')

    source_obj = complex_multi_layered(from_config=True)
    # Make sure the current state is what gets saved, for a
    # Component-only child of torch objects.
    source_obj.child.child.child.x = 24
    state = source_obj.get_state()

    with tempfile.TemporaryDirectory() as save_dir:
        save_state_to_file(state, save_dir)
        list_files(save_dir)
        state_loaded = load_state_from_file(save_dir)
        check_mapping_equivalence(state, state_loaded)

        target_obj = complex_multi_layered(from_config=True)
        target_obj.load_state(state_loaded, strict=False)

        source_state = source_obj.get_state()
        target_state = target_obj.get_state()
        check_mapping_equivalence(target_state, source_state)
        check_mapping_equivalence(source_state._metadata,
                                  target_state._metadata,
                                  exclude_config=False)
def test_state_complex_multilayered_nontorch_root(
        self, complex_multi_layered_nontorch_root):
    """Load one object's state into another and compare weights and states.

    After ``load_state`` the nested linear weight of the second object must
    equal that of the first, and the two states must be equivalent in both
    comparison directions.

    """
    make_component(torch.nn.Module, "torch", only_module='torch.nn')

    source_obj = complex_multi_layered_nontorch_root(from_config=True)
    source_weight = source_obj.item.child.linear.weight
    state = source_obj.get_state()

    target_obj = complex_multi_layered_nontorch_root(from_config=True)
    target_obj.load_state(state)
    target_weight = target_obj.item.child.linear.weight

    assert source_weight.equal(target_weight)
    # Compare in both directions.
    check_mapping_equivalence(target_obj.get_state(), source_obj.get_state())
    check_mapping_equivalence(source_obj.get_state(), target_obj.get_state())
def test_add_extensions_metadata_3(self, complex_multi_layered_nontorch_root):
    """Test that add_extensions_metadata reaches nested torch components.

    The schema built by ``complex_multi_layered_nontorch_root`` contains
    torch components, and ``make_component`` is called on torch so that it
    can be compiled. After adding the (redundant, but valid) torch
    extension, every nested schema whose component comes from a ``torch.``
    module must list torch in its extensions.

    """
    excluded_modules = ['torch.nn.quantized', 'torch.nn.qat']
    make_component(torch.nn.Module, "torch", only_module='torch.nn',
                   exclude=excluded_modules)

    schema = complex_multi_layered_nontorch_root(from_config=True, schema=True)
    schema.add_extensions_metadata({"torch": "torch"})

    def uses_torch(node):
        """Walk ``node`` recursively, asserting on every torch schema found.

        Returns whether at least one torch component was found, which
        should always be the case for complex_multi_layered_nontorch_root.
        """
        if isinstance(node, Schema):
            hit = node.component_subclass.__module__.startswith("torch.")
            if hit:
                assert node._extensions == {"torch": "torch"}
            for child in node.keywords.values():
                # Recurse first so every child is visited and checked.
                hit = uses_torch(child) or hit
            return hit

        if isinstance(node, Mapping):
            hit = False
            for child in node.values():
                hit = uses_torch(child) or hit
            return hit

        return False

    assert uses_torch(schema)
def main(args: argparse.Namespace) -> None:
    """Run the runnable described by ``args.config``, locally or in a cluster.

    A banner is printed, torch and ray tune classes are registered through
    ``make_component`` and ``check_system_reqs`` is called. When
    ``args.cluster`` is ``None`` the preprocessed runnable is run directly.
    Otherwise the cluster config is preprocessed and run, extensions and
    secrets are sent to it, and the runnable is executed through the
    cluster.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments. The attributes read here are ``config``,
        ``cluster``, ``secrets``, ``install_extensions``, ``force`` and
        ``verbose``.

    Raises
    ------
    ValueError
        Raised inside the ``SafeExecutionContext`` blocks when a cluster
        is given but the runnable is not a ``ClusterRunnable``.

    """
    dev_mode = is_dev_mode()
    print(cl.RA(ASCII_LOGO_DEV if dev_mode else ASCII_LOGO))
    if dev_mode:
        repo_location = get_flambe_repo_location()
        print(cl.BL(f"Location: {repo_location}\n"))
    else:
        print(cl.BL(f"VERSION: {flambe.__version__}\n"))

    # Pass original module for ray / pickle
    make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')
    # torch.optim.Optimizer exists, ignore mypy
    make_component(torch.optim.Optimizer, TORCH_TAG_PREFIX,  # type: ignore
                   only_module='torch.optim')
    make_component(torch.optim.lr_scheduler._LRScheduler, TORCH_TAG_PREFIX,
                   only_module='torch.optim.lr_scheduler')
    make_component(ray.tune.schedulers.TrialScheduler, TUNE_TAG_PREFIX)
    make_component(ray.tune.suggest.SearchAlgorithm, TUNE_TAG_PREFIX)

    # TODO check first whether a cluster was given: if so, there
    # is no need to install extensions
    check_system_reqs()

    with SafeExecutionContext(args.config) as ex:
        if args.cluster is None:
            # Local execution
            runnable, _ = ex.preprocess(secrets=args.secrets,
                                        install_ext=args.install_extensions)
            runnable.run(force=args.force, verbose=args.verbose)
        else:
            with SafeExecutionContext(args.cluster) as ex_cluster:
                cluster, _ = ex_cluster.preprocess(
                    secrets=args.secrets,
                    install_ext=args.install_extensions)
                runnable, extensions = ex.preprocess(import_ext=False,
                                                     secrets=args.secrets)
                cluster.run(force=args.force)

                if not isinstance(runnable, ClusterRunnable):
                    raise ValueError(
                        "Only ClusterRunnables can be executed in a cluster.")

                cluster = cast(Cluster, cluster)

                # This is independent of the type of ClusterRunnable
                remote_ext_dir = os.path.join(cluster.get_orch_home_path(),
                                              "extensions")

                # Before sending the extensions, they need to be
                # downloaded (locally).
                local_ext_dir = os.path.join(FLAMBE_GLOBAL_FOLDER,
                                             "extensions")
                extensions = download_extensions(extensions, local_ext_dir)

                # At this point, all remote extensions
                # (except pypi extensions)
                # have local paths.
                new_extensions = cluster.send_local_content(
                    extensions, remote_ext_dir, all_hosts=True)
                new_secrets = cluster.send_secrets()

                # Installing the extensions is crucial as flambe
                # will execute without the '-i' flag and therefore
                # will assume that the extensions are installed
                # in the orchestrator.
                cluster.install_extensions_in_orchestrator(new_extensions)
                logger.info(cl.GR("Extensions installed in Orchestrator"))

                runnable.setup_inject_env(cluster=cluster,
                                          extensions=new_extensions,
                                          force=args.force)
                cluster.execute(runnable, new_extensions, new_secrets,
                                args.force)