class DataLoaderArgument(RelatedConfigMixin):
    """Schema for one entry of a dataloader's `args` section.

    Every field is optional. After construction, `example` and `default`
    are passed through `recursive_dict_parse` with the `'url'` key and
    `RemoteFile.from_config` (presumably turning nested dicts that carry a
    `url` entry into `RemoteFile` objects -- confirm against
    `recursive_dict_parse`).
    """
    # MAYBE - make this a general argument class
    doc = related.StringField("", required=False)
    example = AnyField(required=False)
    default = AnyField(required=False)
    name = related.StringField(required=False)
    type = related.StringField(default='str', required=False)
    optional = related.BooleanField(default=False, required=False)
    tags = StrSequenceField(str, default=[], required=False)  # TODO - restrict the tags

    def __attrs_post_init__(self):
        """Warn about a missing `doc` and parse `example` / `default`."""
        if self.doc == "":
            # `Logger.warn` is a deprecated alias; `warning` is the
            # supported spelling.
            logger.warning("doc empty for one of the dataloader `args` fields")

        # parse args
        self.example = recursive_dict_parse(self.example, 'url', RemoteFile.from_config)
        self.default = recursive_dict_parse(self.default, 'url', RemoteFile.from_config)
class ModelDescription(RelatedLoadSaveMixin):
    """Class representation of model.yaml

    After construction the instance:

    * requires at least one of `defined_as` / `type` to be set,
    * re-parses each `postprocessing` entry with the parser of the matching
      plugin when that plugin is installed,
    * passes `args` through `recursive_dict_parse` with the `'url'` key and
      `RemoteFile.from_config`,
    * converts a dict-valued `default_dataloader` with
      `DataLoaderImport.from_config`.

    Raises:
        ValueError: if both `defined_as` and `type` are None.
    """
    args = related.ChildField(dict)
    info = related.ChildField(ModelInfo)
    schema = related.ChildField(ModelSchema)
    defined_as = related.StringField(required=False)
    type = related.StringField(required=False)
    default_dataloader = AnyField(default='.', required=False)
    postprocessing = related.ChildField(dict, default=OrderedDict(), required=False)
    dependencies = related.ChildField(Dependencies, default=Dependencies(), required=False)
    test = related.ChildField(ModelTest, default=ModelTest(), required=False)
    path = related.StringField(required=False)
    # TODO - add after loading validation for the arguments class?

    def __attrs_post_init__(self):
        """Validate the description and parse its nested sections."""
        if self.defined_as is None and self.type is None:
            raise ValueError("Either defined_as or type need to be specified")

        # load additional objects
        # Only values of existing keys are replaced below, so the dict size
        # does not change while it is being iterated.
        for k_observed in self.postprocessing:
            # the `variant_effects` section is handled by the `kipoi_veff` plugin
            plugin = 'kipoi_veff' if k_observed == 'variant_effects' else k_observed
            if not is_installed(plugin):
                continue
            # Load the config properly if the plugin is installed
            try:
                parser = get_model_yaml_parser(plugin)
                self.postprocessing[k_observed] = parser.from_config(
                    self.postprocessing[k_observed])
                object.__setattr__(self, "postprocessing", self.postprocessing)
            except Exception:
                # Best effort: a broken plugin section must not prevent the
                # model description from loading, but keep the traceback so
                # the cause is not lost.
                logger.warning(
                    "Unable to parse {} field in ModelDescription: {}".format(
                        k_observed, self),
                    exc_info=True)

        # parse args
        self.args = recursive_dict_parse(self.args, 'url', RemoteFile.from_config)

        # parse default_dataloader
        if isinstance(self.default_dataloader, dict):
            self.default_dataloader = DataLoaderImport.from_config(
                self.default_dataloader)
class ModelTest(RelatedLoadSaveMixin):
    """Test settings: the expected predictions and the comparison precision.

    `expect` may be None, a string, or a dict containing a `url` key; the
    dict form is converted with `RemoteFile.from_config` after construction.

    Raises:
        ValueError: if `expect` is neither None, a string, nor a dict with
            a `url` key.
    """
    # predictions = related.
    expect = AnyField(default=None, required=False)
    precision_decimal = related.IntegerField(default=7, required=False)

    # Arrays should be almost equal to `precision_decimal` places
    # https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.testing.assert_almost_equal.html
    # abs(desired-actual) < 1.5 * 10**(-precision_decimal)
    def __attrs_post_init__(self):
        """Validate `expect` and turn its dict form into a `RemoteFile`."""
        expected = self.expect

        # Nothing to do for a missing value or a string.
        if expected is None or isinstance(expected, str):
            return

        # it has to be the url
        is_url_spec = isinstance(expected, dict) and "url" in expected
        if not is_url_spec:
            raise ValueError("expect is not a file path, expecting a url field with entries: url and md5")

        self.expect = RemoteFile.from_config(expected)
class TaskSpec(RelatedConfigMixin):
    """Specification of one task: its count tracks and optional peaks file.

    `pos_counts` and `neg_counts` are either both a single path (str) or
    both lists of paths of equal length; this is enforced in
    `__attrs_post_init__`.
    """
    # Bigwig file paths to counts from
    # the positive and negative strands
    task = related.StringField()
    pos_counts = AnyField()
    neg_counts = AnyField()
    peaks = related.StringField(None, required=False)

    # if True the profile array will be single-stranded
    ignore_strand = related.BooleanField(False, required=False)

    # bias_model = related.StringField(None, required=False)  # if available, provide the bias model
    # implements .predict_on_batch(onehot_seq)
    # bias_bigwig = related.StringField(None, required=False)

    def load_counts(self, intervals, use_strand=True, progbar=False):
        """Extract the counts for `intervals` from the bigwig file(s).

        When `pos_counts` / `neg_counts` are lists, the arrays extracted
        from the individual files are summed.

        Args:
            intervals: intervals handed to `StrandedBigWigExtractor.extract`;
                each one needs a `.strand` attribute when `use_strand` is
                true and `ignore_strand` is false.
            use_strand: forwarded to `StrandedBigWigExtractor`; when true
                (and `ignore_strand` is false) the two output channels are
                swapped for intervals whose strand is '-'.
            progbar: forwarded to `StrandedBigWigExtractor.extract`.

        Returns:
            Array with one extra trailing axis: size 1 holding
            `pos + neg` when `ignore_strand` is set, otherwise size 2.

        Raises:
            ValueError: if `pos_counts` is neither a str nor a list.
        """
        import numpy as np
        # from genomelake.extractors import BigwigExtractor
        from basepair.extractors import StrandedBigWigExtractor

        def extract(path):
            # One code path for both branches below, so `progbar` is
            # honoured for a single file as well as for a list of files.
            extractor = StrandedBigWigExtractor(path,
                                                use_strand=use_strand,
                                                nan_as_zero=True)
            return extractor.extract(intervals, progbar=progbar)

        if isinstance(self.pos_counts, str):
            pos_counts = extract(self.pos_counts)
            neg_counts = extract(self.neg_counts)
        elif isinstance(self.pos_counts, list):
            pos_counts = sum(extract(counts) for counts in self.pos_counts)
            neg_counts = sum(extract(counts) for counts in self.neg_counts)
        else:
            raise ValueError('pos_counts is not a str or a list')

        if self.ignore_strand:
            return (pos_counts + neg_counts)[..., np.newaxis]  # keep the same dimension

        if use_strand:
            # Column vector so that it broadcasts across each interval's
            # values (assumes the counts are (n_intervals, width) -- TODO confirm).
            neg_strand = np.array([s.strand == '-' for s in intervals]).reshape((-1, 1))
            return np.stack([
                np.where(neg_strand, neg_counts, pos_counts),
                np.where(neg_strand, pos_counts, neg_counts),
            ], axis=-1)

        return np.stack([pos_counts, neg_counts], axis=-1)

    def get_bw_dict(self):
        """Return the count specs as `{"pos": pos_counts, "neg": neg_counts}`."""
        return {"pos": self.pos_counts, "neg": self.neg_counts}

    def touch_all_files(self, verbose=True):
        """Call `touch_file` on every positive and negative count file.

        Raises:
            ValueError: if `pos_counts` is neither a str nor a list.
        """
        from basepair.utils import touch_file

        if isinstance(self.pos_counts, str):
            touch_file(self.pos_counts, verbose)
            touch_file(self.neg_counts, verbose)
        elif isinstance(self.pos_counts, list):
            for counts in self.pos_counts:
                touch_file(counts, verbose)
            for counts in self.neg_counts:
                touch_file(counts, verbose)
        else:
            raise ValueError('pos_counts is not a str or a list')

    def abspath(self):
        """Use absolute filepaths

        Returns:
            A deep copy of this object whose `pos_counts`, `neg_counts` and
            `peaks` are passed through `os.path.abspath` (`peaks` stays
            None when unset). The original object is not modified.

        Raises:
            ValueError: if `pos_counts` is neither a str nor a list.
        """
        peaks_abspath = None if self.peaks is None else os.path.abspath(self.peaks)

        obj = deepcopy(self)
        if isinstance(self.pos_counts, str):
            obj.pos_counts = os.path.abspath(self.pos_counts)
            obj.neg_counts = os.path.abspath(self.neg_counts)
        elif isinstance(self.pos_counts, list):
            obj.pos_counts = [os.path.abspath(counts) for counts in self.pos_counts]
            obj.neg_counts = [os.path.abspath(counts) for counts in self.neg_counts]
        else:
            raise ValueError('pos_counts is not a str or a list')
        obj.peaks = peaks_abspath
        return obj

    def __attrs_post_init__(self):
        """Check that `pos_counts` and `neg_counts` are consistent.

        Raises:
            ValueError: if `pos_counts` is neither a str nor a list, if
                `neg_counts` has a different type, or if the two lists
                differ in length.
        """
        if not isinstance(self.pos_counts, (str, list)):
            raise ValueError('pos_counts is not a str or a list')

        # Exact type match is intended here: a str must be paired with a
        # str and a list with a list.
        if type(self.neg_counts) is not type(self.pos_counts):
            raise ValueError('neg_counts has to be same type as pos_counts')

        if isinstance(self.pos_counts, list) and len(self.pos_counts) != len(self.neg_counts):
            raise ValueError('neg_counts has to be same length as pos_counts')