Example #1
    def test_add_extensions_metadata_2(self):
        """Test that add_extensions_metadata doesn't add extensions that are not used.

        In this case we will use a config containing torch, but we will make_component
        on torch so that it can be compiled. After that, we add_extensions_metadata with
        torch, which is a valid extensions for the config (redundant, but valid).
        
        """
        TORCH_TAG_PREFIX = "torch"
        make_component(torch.nn.Module,
                       TORCH_TAG_PREFIX,
                       only_module='torch.nn')

        config = """
        !torch.Linear
          in_features: 2
          out_features: 2
        """

        schema = yaml.load(config)
        schema.add_extensions_metadata({"torch": "torch"})
        assert schema._extensions == {"torch": "torch"}

        mixed_ext = TestSerializationExtensions.EXTENSIONS.copy()
        mixed_ext.update({"torch": "torch"})
        schema.add_extensions_metadata(mixed_ext)
        assert schema._extensions == {"torch": "torch"}
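
For context, the test depends on make_component having registered the torch.nn classes under the "torch" tag prefix, which is what lets yaml.load resolve !torch.Linear above. Below is a minimal sketch of what that registration enables, assuming flambe's yaml/Schema API where a parsed schema is callable to produce the live object (illustrative only, not part of the test file):

# Hedged sketch: compile a registered torch tag into a live module.
# Import paths and the callable-Schema behavior are assumptions.
from flambe.compile import make_component, yaml
import torch

make_component(torch.nn.Module, "torch", only_module='torch.nn')
schema = yaml.load("!torch.Linear {in_features: 2, out_features: 2}")
linear = schema()  # assumed: calling a Schema instantiates the component
assert isinstance(linear, torch.nn.Linear)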
Example #2
    def test_save_to_file_and_load_from_file_roundtrip_complex_nontorch_root(
            self, complex_multi_layered_nontorch_root, pickle_only,
            compress_save_file):
        TORCH_TAG_PREFIX = "torch"
        make_component(torch.nn.Module,
                       TORCH_TAG_PREFIX,
                       only_module='torch.nn')
        old_obj = complex_multi_layered_nontorch_root(from_config=True)
        state = old_obj.get_state()
        with tempfile.TemporaryDirectory() as root_path:
            path = os.path.join(root_path, 'savefile.flambe')
            save_state_to_file(state, path, compress_save_file, pickle_only)
            list_files(path)
            if pickle_only:
                path += '.pkl'
            if compress_save_file:
                path += '.tar.gz'
            state_loaded = load_state_from_file(path)
            check_mapping_equivalence(state, state_loaded)
            check_mapping_equivalence(state._metadata, state_loaded._metadata)
        new_obj = complex_multi_layered_nontorch_root(from_config=True)
        int_state = new_obj.get_state()
        new_obj.load_state(state_loaded, strict=False)
        old_state = old_obj.get_state()
        new_state = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata, new_state._metadata)
        check_mapping_equivalence(int_state._metadata, state_loaded._metadata)
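
The suffix handling in the test mirrors how the saved artifact is named on disk. A hedged summary of the name each flag combination should produce, inferred from the test's own logic rather than from flambe's documentation:

# Hedged sketch of the expected on-disk artifact per flag combination:
#   pickle_only=False, compress=False -> savefile.flambe/  (directory tree)
#   pickle_only=True,  compress=False -> savefile.flambe.pkl
#   pickle_only=False, compress=True  -> savefile.flambe.tar.gz
#   pickle_only=True,  compress=True  -> savefile.flambe.pkl.tar.gz
def saved_path(base: str, pickle_only: bool, compress: bool) -> str:
    if pickle_only:
        base += '.pkl'
    if compress:
        base += '.tar.gz'
    return base

assert saved_path('savefile.flambe', True, True) == 'savefile.flambe.pkl.tar.gz'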
Example #3
    def test_save_to_file_and_load_from_file_roundtrip_complex(
            self, complex_multi_layered):
        TORCH_TAG_PREFIX = "torch"
        make_component(torch.nn.Module,
                       TORCH_TAG_PREFIX,
                       only_module='torch.nn')
        old_obj = complex_multi_layered(from_config=True)
        # Test that the current state is actually saved, for a
        # Component-only child of torch objects
        old_obj.child.child.child.x = 24
        state = old_obj.get_state()
        with tempfile.TemporaryDirectory() as path:
            save_state_to_file(state, path)
            list_files(path)
            state_loaded = load_state_from_file(path)
            check_mapping_equivalence(state, state_loaded)
        new_obj = complex_multi_layered(from_config=True)
        new_obj.load_state(state_loaded, strict=False)
        old_state = old_obj.get_state()
        new_state = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata,
                                  new_state._metadata,
                                  exclude_config=False)
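
The mutation of old_obj.child.child.child.x is the heart of this test: it proves that get_state captures the current attribute values rather than the constructor defaults. Stripped of the file I/O, the property being checked is roughly the following (reusing the fixture from the test; illustrative only):

# Hedged sketch of the invariant, without the file round-trip:
obj = complex_multi_layered(from_config=True)
obj.child.child.child.x = 24          # mutate after construction
fresh = complex_multi_layered(from_config=True)
fresh.load_state(obj.get_state(), strict=False)
assert fresh.child.child.child.x == 24  # the mutation survives, not the default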
Example #4
    def test_state_complex_multilayered_nontorch_root(self, complex_multi_layered_nontorch_root):
        TORCH_TAG_PREFIX = "torch"
        make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')

        obj = complex_multi_layered_nontorch_root(from_config=True)
        t1 = obj.item.child.linear.weight
        state = obj.get_state()
        new_obj = complex_multi_layered_nontorch_root(from_config=True)
        new_obj.load_state(state)
        t2 = new_obj.item.child.linear.weight
        assert t1.equal(t2)
        check_mapping_equivalence(new_obj.get_state(), obj.get_state())
        check_mapping_equivalence(obj.get_state(), new_obj.get_state())
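
Note that torch.Tensor.equal is strict element-wise equality (same shape, dtype, and values), which is the right check here since load_state should copy the weights exactly. For illustration:

# torch.equal demands exact equality; even tiny float drift fails it.
import torch

a = torch.zeros(2, 2)
assert a.equal(a.clone())      # identical contents pass
assert not a.equal(a + 1e-8)   # a minuscule perturbation fails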
Example #5
    def test_add_extensions_metadata_3(self,
                                       complex_multi_layered_nontorch_root):
        """Test that add_extensions_metadata doesn't add extensions that are not used

        In this case we will use a config containing torch, but we will make_component
        on torch so that it can be compiled. After that, we add_extensions_metadata with
        torch, which is a valid extensions for the config (redundant, but valid).

        """
        TORCH_TAG_PREFIX = "torch"
        exclude = ['torch.nn.quantized', 'torch.nn.qat']
        make_component(torch.nn.Module,
                       TORCH_TAG_PREFIX,
                       only_module='torch.nn',
                       exclude=exclude)

        schema = complex_multi_layered_nontorch_root(from_config=True,
                                                     schema=True)
        schema.add_extensions_metadata({"torch": "torch"})

        # This helper asserts recursively that torch is added to the extensions
        # of every subcomponent that uses torch. It returns True if at least one
        # torch component was found, which should always be the case given the
        # complex_multi_layered_nontorch_root fixture.
        def helper(data):
            found = False
            if isinstance(data, Schema):
                if data.component_subclass.__module__.startswith("torch."):
                    found = True
                    assert data._extensions == {"torch": "torch"}

                for val in data.keywords.values():
                    f = helper(val)
                    if f:
                        found = f

            elif isinstance(data, Mapping):
                for val in data.values():
                    f = helper(val)
                    if f:
                        found = f
            return found

        assert helper(schema)
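
The same traversal can also be written as a generator, which separates walking the Schema tree from the assertion. A hedged refactoring of helper, assuming only the component_subclass and keywords attributes the test already relies on:

# Hedged sketch: yield the module name of every component in a Schema tree.
def iter_component_modules(data):
    if isinstance(data, Schema):
        yield data.component_subclass.__module__
        for val in data.keywords.values():
            yield from iter_component_modules(val)
    elif isinstance(data, Mapping):
        for val in data.values():
            yield from iter_component_modules(val)

# usage: assert any(m.startswith("torch.") for m in iter_component_modules(schema))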
Example #6
def main(args: argparse.Namespace) -> None:
    """Execute command based on given config"""
    if is_dev_mode():
        print(cl.RA(ASCII_LOGO_DEV))
        print(cl.BL(f"Location: {get_flambe_repo_location()}\n"))
    else:
        print(cl.RA(ASCII_LOGO))
        print(cl.BL(f"VERSION: {flambe.__version__}\n"))

    # Pass original module for ray / pickle
    make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')
    # torch.optim.Optimizer exists, ignore mypy
    make_component(
        torch.optim.Optimizer,
        TORCH_TAG_PREFIX,  # type: ignore
        only_module='torch.optim')
    make_component(torch.optim.lr_scheduler._LRScheduler,
                   TORCH_TAG_PREFIX,
                   only_module='torch.optim.lr_scheduler')
    make_component(ray.tune.schedulers.TrialScheduler, TUNE_TAG_PREFIX)
    make_component(ray.tune.suggest.SearchAlgorithm, TUNE_TAG_PREFIX)

    # TODO: check first whether there is a cluster; if so, there is
    # no need to install extensions locally.
    check_system_reqs()
    with SafeExecutionContext(args.config) as ex:
        if args.cluster is not None:
            with SafeExecutionContext(args.cluster) as ex_cluster:
                cluster, _ = ex_cluster.preprocess(
                    secrets=args.secrets, install_ext=args.install_extensions)
                runnable, extensions = ex.preprocess(import_ext=False,
                                                     secrets=args.secrets)
                cluster.run(force=args.force)
                if isinstance(runnable, ClusterRunnable):
                    cluster = cast(Cluster, cluster)

                    # This is independent of the type of ClusterRunnable
                    destiny = os.path.join(cluster.get_orch_home_path(),
                                           "extensions")

                    # Before sending the extensions, they need to be
                    # downloaded (locally).
                    t = os.path.join(FLAMBE_GLOBAL_FOLDER, "extensions")
                    extensions = download_extensions(extensions, t)

                    # At this point, all remote extensions
                    # (except pypi extensions)
                    # have local paths.
                    new_extensions = cluster.send_local_content(extensions,
                                                                destiny,
                                                                all_hosts=True)

                    new_secrets = cluster.send_secrets()

                    # Installing the extensions is crucial, as flambe
                    # will execute without the '-i' flag and will
                    # therefore assume that the extensions are already
                    # installed on the orchestrator.
                    cluster.install_extensions_in_orchestrator(new_extensions)
                    logger.info(cl.GR("Extensions installed in Orchestrator"))

                    runnable.setup_inject_env(cluster=cluster,
                                              extensions=new_extensions,
                                              force=args.force)
                    cluster.execute(runnable, new_extensions, new_secrets,
                                    args.force)
                else:
                    raise ValueError(
                        "Only ClusterRunnables can be executed in a cluster.")
        else:
            runnable, _ = ex.preprocess(secrets=args.secrets,
                                        install_ext=args.install_extensions)
            runnable.run(force=args.force, verbose=args.verbose)
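
The cluster branch above stages extensions in three steps: download any remote extensions to a local folder, push them to every host, then install them on the orchestrator. Factored out as a hypothetical helper (stage_extensions is not part of flambe; it only reuses the calls already present in main()):

def stage_extensions(cluster: Cluster, extensions: dict) -> dict:
    """Hypothetical helper: download extensions locally, push them to
    all hosts, and return the remapped (remote) extension paths."""
    destiny = os.path.join(cluster.get_orch_home_path(), "extensions")
    local_dir = os.path.join(FLAMBE_GLOBAL_FOLDER, "extensions")
    local_ext = download_extensions(extensions, local_dir)
    return cluster.send_local_content(local_ext, destiny, all_hosts=True)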