def wrap(cls):
    """Attach a ``DataLoaderDescription`` parsed from ``cls``'s docstring.

    The class docstring is parsed as the yaml dataloader description. The
    top-level ``defined_as`` and ``type`` keys are filled in automatically
    (from ``cls.__name__`` and the inferred parent dataloader class) when the
    docstring does not declare them itself. Entries of ``override`` (not
    defined here; presumably the mapping given to the enclosing
    ``kipoi_dataloader`` decorator factory -- confirm) are then applied on top
    of the parsed description via ``rsetattr``.

    Args:
      cls: dataloader class being decorated. It must inherit from one of
        ``AVAILABLE_DATALOADERS`` and carry a yaml docstring.

    Returns:
      Whatever ``cls._add_description_factory(dl_descr)`` returns.

    Raises:
      ValueError: if ``cls`` is a function, does not inherit from an available
        dataloader, has no docstring, or if the ``args`` described in the
        docstring are not exactly the arguments reported for ``cls``.
    """
    if inspect.isfunction(cls):
        raise ValueError(
            "Function-based dataloader are currently not supported with kipoi_dataloader decorator"
        )

    # figure out the right dataloader type
    dl_type_inferred = infer_parent_class(cls, AVAILABLE_DATALOADERS)
    if dl_type_inferred is None:
        raise ValueError(
            "Dataloader needs to inherit from one of the available dataloaders {}"
            .format(list(AVAILABLE_DATALOADERS)))

    if cls.__doc__ is None:
        # e.g. no docstring written, or docstrings stripped by `python -OO`;
        # textwrap.dedent(None) would otherwise fail with an opaque TypeError.
        raise ValueError(
            "Dataloader {} needs a yaml docstring describing it".format(cls.__name__))
    doc = textwrap.dedent(cls.__doc__)  # de-indent so top-level keys sit at column 0

    # Docstrings typically start with a newline and may declare these keys in
    # any order, so check for the key at the start of *any* line. Anchoring
    # at the start of the whole string (re.match) almost never matched and
    # led to the key being prepended a second time (duplicate yaml key).
    if not re.search(r"^defined_as: ", doc, flags=re.MULTILINE):
        doc = "defined_as: {}\n".format(cls.__name__) + doc
    if not re.search(r"^type: ", doc, flags=re.MULTILINE):
        doc = "type: {}\n".format(dl_type_inferred) + doc

    # parse the yaml
    yaml_dict = related.from_yaml(doc)
    dl_descr = DataLoaderDescription.from_config(yaml_dict)

    # override parameters (keys are passed to rsetattr; presumably dotted
    # attribute paths -- confirm against rsetattr)
    for k, v in six.iteritems(override):
        rsetattr(dl_descr, k, v)

    # setup optional parameters
    arg_names, default_values = _get_arg_name_values(cls)
    arg_names = list(arg_names)

    if set(dl_descr.args) != set(arg_names):
        raise ValueError(
            "Described args don't exactly match the implemented arguments. "
            "docstring: {}, actual: {}".format(list(dl_descr.args), arg_names))

    # properly set optional / non-optional argument values.
    # Assumes (as the previous index arithmetic did) that `arg_names` is in
    # signature order and the trailing len(default_values) of them have
    # defaults. Each described arg is located by its signature position,
    # because the docstring may list the args in a different order.
    n_required = len(arg_names) - len(default_values)
    signature_position = {name: i for i, name in enumerate(arg_names)}
    for arg in dl_descr.args:
        optional = signature_position[arg] >= n_required
        if dl_descr.args[arg].optional and not optional:
            logger.warning(
                "Parameter {} was specified as optional. However, there "
                "are no defaults for it. Specifying it as not optinal".
                format(arg))
        dl_descr.args[arg].optional = optional

    dl_descr.info.name = cls.__name__

    # enrich the class with dataloader description
    return cls._add_description_factory(dl_descr)
def _assert_seq_input_setup(mi, rc_support):
    """Check the per-sequence-input setup derived by a ModelInfoExtractor.

    Shared by both dataloader variants in ``test_ModelDescription`` so they
    are held to exactly the same expectations.
    """
    # input name -> (expected mutator class, expected metadata field,
    #                expected array transformation class)
    expected = {
        "seq_a": (OneHotSequenceMutator, "ranges", ReshapeDna),
        "seq_b": (DNAStringSequenceMutator, "ranges", ReshapeDnaString),
        "seq_c": (OneHotSequenceMutator, "ranges_b", ReshapeDna),
    }
    assert mi.use_seq_only_rc == rc_support
    for seq_name, (mutator_cls, metadata_field, trafo_cls) in expected.items():
        # the input name is the assertion message so a failure points at it
        assert isinstance(mi.seq_input_mutator[seq_name], mutator_cls), seq_name
        assert mi.seq_input_metadata[seq_name] == metadata_field, seq_name
        assert isinstance(mi.seq_input_array_trafo[seq_name], trafo_cls), seq_name


def test_ModelDescription():
    """ModelInfoExtractor derives the expected setup for every sequence input.

    Run with and without ``supports_simple_rc_str`` substituted into the model
    yaml, and against dataloader schemas both with and without input shapes.
    """
    seq_string_shape = ""
    for rc_support in [True, False]:
        ssrs = supports_simple_rc_str if rc_support else ""
        model_yaml_str = model_yaml % (seq_string_shape, ssrs)

        # Dataloader schema that declares the input shapes.
        model = ModelDescription.from_config(from_yaml(model_yaml_str))
        dataloader = DataLoaderDescription.from_config(
            from_yaml(dataloader_yaml % (seq_string_shape)))
        _assert_seq_input_setup(ModelInfoExtractor(model, dataloader), rc_support)

        # The model info extractor must also work when the shapes are missing
        # from the dataloader schema definition.
        model = ModelDescription.from_config(from_yaml(model_yaml_str))
        dataloader = DataLoaderDescription.from_config(
            from_yaml(dataloader_yaml_noshapes))
        _assert_seq_input_setup(ModelInfoExtractor(model, dataloader), rc_support)