def it_returns_required_elems():
    # requirements() yields one (key, type, help, userdata) tuple per element;
    # is_number() is reported as float.
    user_data = dict(some_key=1)
    elems = dict(
        a=s.is_int(),
        b=s.is_float(help="A float"),
        c=s.is_number(),
        d=s.is_str(userdata=user_data),
        e=s.is_list(),
        f=s.is_dict(all_required=True, elems=dict(d=s.is_int(), e=s.is_int())),
    )
    test_s = s(s.is_dict(all_required=True, elems=elems))
    actual = test_s.requirements()
    expected = [
        ("a", int, None, None),
        ("b", float, "A float", None),
        ("c", float, None, None),
        ("d", str, None, user_data),
        ("e", list, None, None),
        ("f", dict, None, None),
    ]
    assert actual == expected
def _before(): nonlocal test_s test_s = s( s.is_dict( all_required=True, elems=dict( a=s.is_int(), b=s.is_int(), c=s.is_dict(all_required=True, elems=dict(d=s.is_int(), e=s.is_int())), ), ))
def it_validates_recursively():
    # Nested list/dict element schemas must be checked, not just the top level.
    schema = s(
        s.is_dict(
            elems=dict(
                a=s.is_int(),
                b=s.is_list(required=True, elems=s.is_str()),
                c=s.is_dict(required=True),
            )
        )
    )
    schema.validate(dict(a=1, b=["a", "b"], c=dict()))

    bad_inputs = (
        dict(a=1, b=[1], c=dict()),  # list element is not a str
        dict(a=1, b=["a"], c=1),  # c is not a dict
    )
    for bad in bad_inputs:
        with zest.raises(SchemaValidationFailed):
            schema.validate(bad)
def it_ignores_underscored_keys():
    # Without ignore_underscored_keys, no_extras rejects "_c" like any other extra.
    strict = s(s.is_dict(elems=dict(a=s.is_int(), b=s.is_int()), no_extras=True))
    for bad in (dict(a=1, b="str"), dict(a=1, b=2, _c=[])):
        with zest.raises(SchemaValidationFailed):
            strict.validate(bad)

    # With it, underscore-prefixed keys are exempt from the no_extras check.
    lenient = s(
        s.is_dict(
            elems=dict(a=s.is_int(), b=s.is_int()),
            no_extras=True,
            ignore_underscored_keys=True,
        )
    )
    lenient.validate(dict(a=1, b=2, _c=[]))
def it_shows_help():
    # help() should emit one entry per schema node, depth-first, starting at "root".
    schema = s(
        s.is_kws(
            a=s.is_dict(
                help="Help for a",
                elems=dict(
                    b=s.is_int(help="Help for b"),
                    c=s.is_kws(d=s.is_int(help="Help for d")),
                ),
            )
        )
    )
    schema.help()

    # m_print_help is defined outside this view; presumably a mock of the help
    # printer whose normalized calls carry "key" and "help" — confirm.
    observed = [{call["key"]: call["help"]} for call in m_print_help.normalized_calls()]
    assert observed == [
        {"root": None},
        {"a": "Help for a"},
        {"b": "Help for b"},
        {"c": None},
        {"d": "Help for d"},
    ]
def it_fetches_user_data():
    # The 4th item (index 3) of a top-level field tuple is its userdata.
    schema = s(
        s.is_dict(
            help="Help for a",
            elems=dict(
                b=s.is_int(help="Help for b", userdata="userdata_1"),
                c=s.is_kws(d=s.is_int(help="Help for d")),
            ),
        )
    )
    first_field = schema.top_level_fields()[0]
    assert first_field[0] == "b" and first_field[3] == "userdata_1"
class Priors:
    """
    A collection of named Prior instances with hierarchical lookup.

    Names are dot-separated, e.g. "gain_mu.ch_0". get() tries the full name,
    then each shorter prefix ("gain_mu"), and finally a built-in default keyed
    by the first component (built in __init__).

    See above docs
    """

    # Elems mirror the keys _instanciate_priors_desc reads from each desc entry.
    priors_desc_schema = s.is_dict(
        elems=dict(
            class_name=s.is_str(),
            hyper_params=s.is_dict(),
            params=s.is_dict(required=False),
        ),
        required=False,
    )

    def add(self, name, prior, source=None):
        """
        Store prior under name, stamping its source first when given.

        Raises (via assert) if the prior ends up with no source at all.
        """
        check.t(prior, Prior)
        if source is not None:
            prior.source = source
        assert prior.source is not None
        self.priors[name] = prior

    @staticmethod
    def _is_ch_specific(key):
        """True when the second dotted component of key starts with "ch_"."""
        parts = key.split(".")
        return len(parts) > 1 and parts[1].startswith("ch_")

    def delete_ch_specific_records(self):
        """Drop every prior whose name looks like "<root>.ch_<...>"."""
        # Single filtering pass (the old list-membership test was O(n^2)).
        self.priors = {
            key: val
            for key, val in self.priors.items()
            if not self._is_ch_specific(key)
        }

    def _instanciate_prior(
        self, source, name, class_name, hyper_params, params, overwrite=False
    ):
        """
        ADD a prior instance of class_name from source into name.

        class_name without a dot is looked up in this module; otherwise the
        part before the last dot is imported as a module. A PriorsIncludeFile
        instance is expanded recursively instead of being stored.

        Raises ValueError if name already exists and overwrite is False.
        """
        # NOTE(review): class_name selects an arbitrary importable module and
        # class; only feed this trusted task parameters.
        parts = class_name.split(".")
        if len(parts) == 1:
            # Use this module
            module = sys.modules[__name__]
        else:
            module = importlib.import_module(".".join(parts[0:-1]))

        klass = getattr(module, parts[-1])
        instance = klass(**(hyper_params or {}))
        if isinstance(instance, PriorsIncludeFile):
            # RECURSE for include file
            self._instanciate_priors_desc(
                f"Include file '{instance.path}'", instance.priors_desc
            )
        else:
            instance.source = source
            instance.deserialize(params)
            if name in self.priors and not overwrite:
                raise ValueError(f"Duplicate prior name '{name}'")
            self.add(name, instance)

    def _instanciate_priors_desc(self, source, priors_desc):
        """Instantiate every {name: desc} entry of priors_desc (see _instanciate_prior)."""
        for name, desc in priors_desc.items():
            self._instanciate_prior(
                source,
                name,
                desc["class_name"],
                desc.get("hyper_params"),
                desc.get("params"),
            )

    @classmethod
    def from_priors_desc(cls, source, priors_desc):
        """
        Create Priors from a desc, for example from a task parameters block.
        These can include "PriorsIncludeFile" which recursively add.
        """
        # cls() (not Priors()) so subclasses get an instance of themselves,
        # consistent with copy().
        priors = cls()
        priors._instanciate_priors_desc(source, priors_desc)
        return priors

    def __init__(self, hyper_im_mea=None, hyper_n_channels=None):
        """
        hyper_im_mea: when given, build default "reg_illum" and "reg_psf" priors.
        hyper_n_channels: when given, build a default "ch_aln" prior
            (zeros of shape (hyper_n_channels, 2)).
        Defaults left as None are treated as "no default" by get().
        """
        self._default_reg_illum = None
        self._default_reg_psf = None
        self._default_ch_aln = None

        if hyper_im_mea is not None:
            self._default_reg_illum = RegIllumPrior(hyper_im_mea).set(
                hyper_im_mea / 2, hyper_im_mea / 2, 0.0
            )

            hyper_n_divs = 5
            sigma_x = np.full((hyper_n_divs, hyper_n_divs), 1.8)
            sigma_y = np.full((hyper_n_divs, hyper_n_divs), 1.8)
            rho = np.full((hyper_n_divs, hyper_n_divs), 0.0)
            self._default_reg_psf = RegPSFPrior(
                hyper_im_mea,
                hyper_peak_mea=default_hyper_peak_mea,
                hyper_n_divs=hyper_n_divs,
            ).set(sigma_x=sigma_x, sigma_y=sigma_y, rho=rho)

        if hyper_n_channels is not None:
            self._default_ch_aln = ChannelAlignPrior(hyper_n_channels).set(
                np.zeros((hyper_n_channels, 2))
            )

        self._defaults = {
            "reg_illum": self._default_reg_illum,
            "reg_psf": self._default_reg_psf,
            "ch_aln": self._default_ch_aln,
            "gain_mu": MLEPrior().set(value=7500.0),
            "gain_sigma": MLEPrior().set(value=0.0),
            "bg_mu": MLEPrior().set(value=0.0),
            "bg_sigma": MLEPrior().set(value=100.0),
            "row_k_sigma": MLEPrior().set(value=0.15),
            "p_edman_failure": MLEPrior().set(value=0.06),
            "p_detach": MLEPrior().set(value=0.05),
            "p_bleach": MLEPrior().set(value=0.05),
            "p_non_fluorescent": MLEPrior().set(value=0.07),
        }

        self.priors = {}

    def get_exact(self, request_name):
        """
        Look up exact match or return None
        """
        if request_name in self.priors:
            return Munch(
                request_name=request_name,
                matched_name=request_name,
                prior=self.priors[request_name],
            )
        return None

    def get(self, request_name):
        """
        Search for a request_name looking up the naming tree for a match if needed.

        Returns a Munch(request_name, matched_name, prior). Falls back to the
        non-None default registered under the first name component, whose
        source is then set to "defaults".

        Raises KeyError if nothing matches.
        """
        parts = request_name.split(".")
        for i in range(len(parts), 0, -1):
            # SEARCH up the hierarchy
            matched_name = ".".join(parts[0:i])
            if matched_name in self.priors:
                return Munch(
                    request_name=request_name,
                    matched_name=matched_name,
                    prior=self.priors[matched_name],
                )

        # Not found, look in defaults
        matched_name = parts[0]
        if matched_name in self._defaults:
            default = self._defaults[parts[0]]
            if default is not None:
                default.source = "defaults"
                return Munch(
                    request_name=request_name,
                    matched_name=matched_name,
                    prior=default,
                )

        raise KeyError(f"Prior '{request_name}' not resolved")

    def enumerate_names(self):
        """List the names of explicitly added priors (defaults excluded)."""
        return list(self.priors.keys())

    def remove(self, name):
        """Delete an explicitly added prior; raises KeyError if absent."""
        del self.priors[name]

    def update(self, other, source):
        """Deep-copy every prior of other into self, re-stamping source."""
        for key, value in other.priors.items():
            value = copy.deepcopy(value)
            value.source = source
            self.priors[key] = value

    def get_distr(self, request_name):
        """
        Like get() but returns the distribution object only (without the other metadata)
        """
        p = self.get(request_name)
        return p.prior

    def get_sample(self, request_name):
        """
        Like get() but returns a sample from the distribution
        """
        p = self.get(request_name)
        return p.prior.sample()

    def get_mle(self, request_name, **kwargs):
        """
        Like get() but returns the MLE.
        For now, this means that the prior must come from an MLEPrior
        """
        p = self.get(request_name)
        assert isinstance(p.prior, MLEPrior)
        return p.prior.sample()

    def helper_illum_model(self, n_channels):
        """
        Used by nn_v2 to load the C context with cols:
            gain_mu, gain_sigma, bg_mu, bg_sigma
        One row per channel. The bg_mu column is always 0.0 here; bg_mu
        priors are not consulted.

        Note that this is still providing the older functionality
        of returning a MLE value for the parameters. Eventually this
        is going to be changed so that the nn_v2 will have a parametric
        description of the prior so that it can draw samples instead
        of using the MLE.
        """
        illum_model = np.zeros((n_channels, 4))
        for ch_i in range(n_channels):
            # get_mle asserts each resolved prior is an MLEPrior.
            illum_model[ch_i, 0] = self.get_mle(f"gain_mu.ch_{ch_i}")
            illum_model[ch_i, 1] = self.get_mle(f"gain_sigma.ch_{ch_i}")
            illum_model[ch_i, 2] = 0.0
            illum_model[ch_i, 3] = self.get_mle(f"bg_sigma.ch_{ch_i}")
        return illum_model

    @classmethod
    def copy(cls, src):
        """
        Return a new instance holding deep copies of src's priors AND of its
        hyper-parameter defaults, so the copy resolves names exactly as src does.
        """
        self = cls()
        self.priors = copy.deepcopy(src.priors)
        if hasattr(src, "_defaults"):
            # One shared memo keeps each _default_* attribute and the matching
            # entry in _defaults as the same object, mirroring src.
            memo = {}
            for attr in (
                "_defaults",
                "_default_reg_illum",
                "_default_reg_psf",
                "_default_ch_aln",
            ):
                setattr(self, attr, copy.deepcopy(getattr(src, attr), memo))
        return self
def it_checks_no_extra():
    # With no_extras=True an undeclared key ("c") must be rejected.
    test_s = s(s.is_dict(elems=dict(a=s.is_int(), b=s.is_int()), no_extras=True))
    with_extra_key = dict(a=1, b=1, c=1)
    with zest.raises(SchemaValidationFailed):
        test_s.validate(with_extra_key)
def it_allows_elems_in_dict():
    # The elems mapping is passed positionally to is_dict; building it must not raise.
    elems = dict(a=s.is_int(noneable=True))
    s(s.is_dict(elems))
def it_does_not_overwrite_an_existing_dict():
    test_s = s(s.is_kws_r(a=s.is_dict()))
    target = dict(a=dict(b=1))

    # "a" already holds a good value; a None default must not replace it.
    test_s.apply_defaults(defaults=dict(a=None), apply_to=target)
    assert target["a"]["b"] == 1
def it_all_required_false_by_default():
    # "b" is omitted; without all_required=True this must still validate.
    schema = s(s.is_dict(elems=dict(a=s.is_int(), b=s.is_int())))
    partial = dict(a=1)
    schema.validate(partial)
def it_validates_default_dict():
    # A bare is_dict() accepts any dict but rejects non-dicts.
    test_s = s(s.is_dict())
    for good in ({}, dict(a=1, b=2)):
        test_s.validate(good)

    with zest.raises(SchemaValidationFailed):
        test_s.validate(1)
class SimV1Params(ParamsAndPriors):
    """
    Simulation parameters: an error model plus the parameters for sim.
    The error-model terms are resolved from self.priors (see _setup_dfs).
    """

    defaults = Munch(
        n_pres=1,
        n_mocks=0,
        n_edmans=1,
        n_samples_train=5_000,
        n_samples_test=1_000,
        dyes=[],
        labels=[],
        random_seed=None,
        train_n_sample_multiplier=None,  # This does not appear to be used anywhere. tfb
        allow_train_test_to_be_identical=False,
        enable_ptm_labels=False,
        is_survey=False,
    )

    schema = s(
        s.is_kws_r(
            is_survey=s.is_bool(),
            priors_desc=Priors.priors_desc_schema,
            n_pres=s.is_int(bounds=(0, None)),
            n_mocks=s.is_int(bounds=(0, None)),
            n_edmans=s.is_int(bounds=(0, None)),
            n_samples_train=s.is_int(bounds=(1, None)),
            n_samples_test=s.is_int(bounds=(1, None)),
            dyes=s.is_list(
                elems=s.is_kws_r(dye_name=s.is_str(), channel_name=s.is_str())
            ),
            labels=s.is_list(
                elems=s.is_kws_r(
                    aa=s.is_str(),
                    dye_name=s.is_str(),
                    label_name=s.is_str(),
                    ptm_only=s.is_bool(required=False, noneable=True),
                )
            ),
            channels=s.is_dict(required=False),
            random_seed=s.is_int(required=False, noneable=True),
            allow_train_test_to_be_identical=s.is_bool(required=False, noneable=True),
            enable_ptm_labels=s.is_bool(required=False, noneable=True),
        )
    )

    # def copy(self):
    #     # REMOVE everything that _build_join_dfs put in
    #     utils.safe_del(self, "df")
    #     utils.safe_del(self, "by_channel")
    #     utils.safe_del(self, "ch_by_aa")
    #
    #     dst = utils.munch_deep_copy(self, klass_set={SimV1Params})
    #     dst.error_model = ErrorModel(**dst.error_model)
    #     assert isinstance(dst, SimV1Params)
    #     return dst

    def __init__(self, **kwargs):
        super().__init__(source="SimV1Params", **kwargs)
        self._setup_dfs()

    def validate(self):
        """
        Schema-validate, then check dye/label consistency and the channel map.
        If "channels" is absent it is assigned by sorted channel name.
        """
        super().validate()

        all_dye_names = {d.dye_name for d in self.dyes}

        # No duplicate dye names
        self._validate(
            len(all_dye_names) == len(self.dyes), "The dye list contains a duplicate"
        )

        # No duplicate labels (one label per aa)
        self._validate(
            len(set(utils.listi(self.labels, "aa"))) == len(self.labels),
            "There is a duplicate label",
        )

        # All labels have a legit dye name. A plain loop: the old list
        # comprehension was evaluated only for its side effects.
        for label in self.labels:
            self._validate(
                label.dye_name in all_dye_names,
                f"Label {label.label_name} does not have a valid matching dye_name",
            )

        # Channel mappings
        mentioned_channels = {dye.channel_name: False for dye in self.dyes}
        if "channels" in self:
            # Validate that channel mapping is complete
            for channel_name in self.channels.keys():
                self._validate(
                    channel_name in mentioned_channels,
                    f"Channel name '{channel_name}' was not found in dyes",
                )
                mentioned_channels[channel_name] = True

            self._validate(
                all(mentioned_channels.values()),
                "Not all channels in dyes were enumerated in channels",
            )
        else:
            # No channel mapping: assign them in sorted-name order
            self["channels"] = {
                ch_name: i
                for i, ch_name in enumerate(sorted(mentioned_channels.keys()))
            }

    @property
    def n_cycles(self):
        return self.n_pres + self.n_mocks + self.n_edmans

    def channel_names(self):
        """Sorted, de-duplicated channel names mentioned by the dyes."""
        return sorted(set(utils.listi(self.dyes, "channel_name")))

    def channel_i_by_name(self):
        """Map channel name -> index in sorted-name order."""
        channels = self.channel_names()
        return {
            channel_name: channel_i for channel_i, channel_name in enumerate(channels)
        }

    @property
    def n_channels(self):
        return len(self.channel_i_by_name().keys())

    @property
    def n_channels_and_cycles(self):
        return self.n_channels, self.n_cycles

    def _setup_dfs(self):
        """
        Join the dye, label and channel priors into DataFrames.

        Priors are wired together by name, which is useful for reconciling
        calibrations. Here those "by name" parameters are put into DataFrames
        so that they can be indexed by integers.
        """
        dyes_df = pd.DataFrame(self.dyes)
        assert len(dyes_df) > 0

        labels_df = pd.DataFrame(self.labels)
        assert len(labels_df) > 0

        # LOOKUP dye priors
        dye_priors = []
        for dye in self.dyes:
            # SEARCH priors by dye name and if not found by channel
            # NOTE(review): the fallback key embeds channel_name ("ch_<name>")
            # while the channel priors below use "ch_<ch_i>"; with names like
            # "ch_0" this becomes "ch_ch_0" — confirm this is intended.
            p_non_fluorescent = self.priors.get_exact(
                f"p_non_fluorescent.{dye.dye_name}"
            )
            if p_non_fluorescent is None:
                p_non_fluorescent = self.priors.get(
                    f"p_non_fluorescent.ch_{dye.channel_name}"
                )

            dye_priors += [
                Munch(
                    dye_name=dye.dye_name,
                    p_non_fluorescent=p_non_fluorescent.prior,
                )
            ]

        dye_priors_df = pd.DataFrame(dye_priors)
        # dye_priors_df: (dye_name, p_non_fluorescent)

        dyes_df = utils.easy_join(dyes_df, dye_priors_df, "dye_name")
        # dyes_df: (dye_name, channel_name, p_non_fluorescent)

        # TODO: LOOKUP label priors
        #   (p_failure_to_bind_aa, p_failure_to_attach_to_dye)

        # LOOKUP channel priors
        ch_priors = pd.DataFrame(
            [
                dict(
                    channel_name=channel_name,
                    ch_i=ch_i,
                    bg_mu=self.priors.get(f"bg_mu.ch_{ch_i}").prior,
                    bg_sigma=self.priors.get(f"bg_sigma.ch_{ch_i}").prior,
                    gain_mu=self.priors.get(f"gain_mu.ch_{ch_i}").prior,
                    gain_sigma=self.priors.get(f"gain_sigma.ch_{ch_i}").prior,
                    row_k_sigma=self.priors.get(f"row_k_sigma.ch_{ch_i}").prior,
                    p_bleach=self.priors.get(f"p_bleach.ch_{ch_i}").prior,
                )
                for channel_name, ch_i in self.channels.items()
            ]
        )
        # ch_priors: (channel_name, ch_i, ...)

        self._channel__priors = (
            utils.easy_join(dyes_df, ch_priors, "channel_name")
            .drop(columns=["p_non_fluorescent"])
            .drop_duplicates()
            .reset_index()
        )
        # self._channel__priors: (
        #    'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
        #    'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        # )

        # SANITY check channel__priors
        group_by_ch = self._channel__priors.groupby("ch_i")
        for field in (
            "bg_mu",
            "bg_sigma",
            "gain_mu",
            "gain_sigma",
            "row_k_sigma",
        ):
            assert np.all(group_by_ch[field].nunique() == 1)
        assert "p_non_fluorescent" not in self._channel__priors.columns

        labels_dyes_df = utils.easy_join(labels_df, dyes_df, "dye_name")
        self._dye__label__priors = utils.easy_join(
            labels_dyes_df, ch_priors, "channel_name"
        ).reset_index(drop=True)
        # self._dye__label__priors: (
        #   'channel_name', 'dye_name', 'aa', 'label_name',
        #   'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
        #   'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        # )

        self._ch_by_aa = {
            row.aa: row.ch_i for row in self._dye__label__priors.itertuples()
        }

    def dye__label__priors(self):
        """
        DataFrame(
            'channel_name', 'dye_name', 'aa', 'label_name',
            'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
            'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        )
        """
        return self._dye__label__priors

    def channel__priors(self):
        """
        DataFrame(
            'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
            'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        )
        """
        return self._channel__priors

    def by_channel(self):
        return self._channel__priors.set_index("ch_i")

    def to_label_list(self):
        """Summarize labels like: ["DE", "C"]"""
        return [
            "".join(
                [label.aa for label in self.labels if label.dye_name == dye.dye_name]
            )
            for dye in self.dyes
        ]

    def to_label_str(self):
        """Summarize labels like: DE,C"""
        return ",".join(self.to_label_list())

    @classmethod
    def construct_from_aa_list(cls, aa_list, **kwargs):
        """
        Build params with one dye/channel per entry of aa_list.

        For example, two channels where ch0 is D&E and ch1 is Y: ["DE", "Y"].
        Extra kwargs are forwarded to the constructor.
        """
        check.list_or_tuple_t(aa_list, str)

        allowed_aa_mods = ["[", "]"]
        assert all(
            (aa.isalpha() or aa in allowed_aa_mods) for aas in aa_list for aa in aas
        )

        dyes = [
            Munch(dye_name=f"dye_{ch}", channel_name=f"ch_{ch}")
            for ch, _ in enumerate(aa_list)
        ]

        # Note the extra for loop because "DE" needs to be split into "D" & "E"
        # which is done by aa_str_to_list() - which also handles PTMs like S[p]
        labels = [
            Munch(
                aa=aa,
                dye_name=f"dye_{ch}",
                label_name=f"label_{ch}",
                ptm_only=False,
            )
            for ch, aas in enumerate(aa_list)
            for aa in aa_str_to_list(aas)
        ]

        return cls(dyes=dyes, labels=labels, **kwargs)
class SigprocV1Params(Params):
    # Applied by validate() with override_nones=False.
    defaults = dict(
        hat_rad=2,
        iqr_rng=96,
        threshold_abs=1.0,
        channel_indices_for_alignment=None,
        channel_indices_for_peak_finding=None,
        radiometry_channels=None,
        save_debug=False,
        peak_find_n_cycles=4,
        peak_find_start=0,
        radial_filter=None,
        anomaly_iqr_cutoff=95,
        n_fields_limit=None,
        save_full_signal_radmat_npy=False,
    )

    schema = s(
        s.is_kws_r(
            anomaly_iqr_cutoff=s.is_number(noneable=True, bounds=(0, 100)),
            radial_filter=s.is_float(noneable=True, bounds=(0, 1)),
            peak_find_n_cycles=s.is_int(bounds=(1, None), noneable=True),
            peak_find_start=s.is_int(bounds=(0, None), noneable=True),
            save_debug=s.is_bool(),
            hat_rad=s.is_int(bounds=(1, 3)),
            iqr_rng=s.is_number(noneable=True, bounds=(0, 100)),
            threshold_abs=s.is_number(bounds=(0, 100)),  # Not sure of a reasonable bound
            channel_indices_for_alignment=s.is_list(s.is_int(), noneable=True),
            channel_indices_for_peak_finding=s.is_list(s.is_int(), noneable=True),
            radiometry_channels=s.is_dict(noneable=True),
            n_fields_limit=s.is_int(noneable=True),
            save_full_signal_radmat_npy=s.is_bool(),
        )
    )

    def validate(self):
        """
        Apply defaults, check the schema, then check radiometry_channels entries.

        super().validate() is deliberately bypassed so defaults can be applied
        with override_nones=False.
        """
        self.schema.apply_defaults(self.defaults, apply_to=self, override_nones=False)
        self.schema.validate(self, context=self.__class__.__name__)

        mapping = self.radiometry_channels
        if mapping is None:
            return

        name_re = re.compile(r"[0-9a-z_]+")
        for ch_name, ch_index in mapping.items():
            self._validate(
                name_re.fullmatch(ch_name),
                "radiometry_channels name must be lower-case alphanumeric (including underscore)",
            )
            self._validate(isinstance(ch_index, int), "channel_i must be an integer")

    def set_radiometry_channels_from_input_channels_if_needed(self, n_channels):
        """When radiometry_channels is unset, map "ch_<i>" -> i for every input channel."""
        if self.radiometry_channels is not None:
            return
        # Assume channels from nd2 manifest
        self.radiometry_channels = {f"ch_{ch_i}": ch_i for ch_i in range(n_channels)}

    @property
    def n_output_channels(self):
        return len(self.radiometry_channels.keys())

    @property
    def n_input_channels(self):
        return len(self.radiometry_channels.keys())

    @property
    def channels_cycles_dim(self):
        # Cache populated by sigproc_v1 to avoid recomputing
        # (n_outchannels, n_inchannels, n_cycles, dim) on every call.
        return self._outchannels_inchannels_cycles_dim

    def _input_channels(self):
        """
        Map each output channel index to its input channel index.

        Output channels are ordered by radiometry channel name. Example: with
        radiometry_channels == {"bar": 1}, output channel 0 ("bar") reads
        input channel 1, so this returns [1].
        """
        by_name = sorted(self.radiometry_channels.items(), key=lambda item: item[0])
        return [in_ch for _, in_ch in by_name]

    # def input_names(self):
    #     return sorted(self.radiometry_channels.keys())

    def output_channel_to_input_channel(self, out_ch):
        """Input channel index feeding output channel out_ch."""
        mapping = self._input_channels()
        return mapping[out_ch]

    def input_channel_to_output_channel(self, in_ch):
        """Not every input channel necessarily has an output; can return None"""
        return utils.filt_first_arg(self._input_channels(), lambda x: x == in_ch)
def it_fetches_list_elem_type():
    # The 5th item (index 4) of a top-level field tuple is a list's element type.
    schema = s(s.is_dict(elems=dict(a=s.is_list(s.is_int()))))
    first_field = schema.top_level_fields()[0]
    assert first_field[0] == "a" and first_field[4] is int
def it_checks_no_extra_flase_by_default():
    # "c" is not declared in elems; without no_extras=True it is tolerated.
    schema = s(s.is_dict(elems=dict(a=s.is_int(), b=s.is_int())))
    with_extra_key = dict(a=1, c=1)
    schema.validate(with_extra_key)
def it_allows_all_required_to_be_overriden():
    # An explicit required=False on an element beats all_required=True.
    elems = dict(a=s.is_int(required=False))
    schema = s(s.is_dict(elems=elems, all_required=True))
    schema.validate(dict())
def it_checks_key_type_str():
    # A non-string key in elems is expected to make validation fail.
    schema = s(s.is_dict(elems={1: s.is_int()}))
    int_keyed = {1: 2}
    with zest.raises(SchemaValidationFailed):
        schema.validate(int_keyed)
def it_checks_required():
    # Omitting a required=True element must fail validation.
    elems = dict(a=s.is_int(required=True), b=s.is_int())
    schema = s(s.is_dict(elems=elems))
    missing_a = dict(b=1)
    with zest.raises(SchemaValidationFailed):
        schema.validate(missing_a)
class SigprocV2Params(Params):
    # Applied by validate() with override_nones=False.
    defaults = dict(
        radiometry_channels=None,
        n_fields_limit=None,
        save_full_signal_radmat_npy=False,
        # use_cycle_zero_psfs_only=False,
    )

    schema = s(
        s.is_kws_r(
            radiometry_channels=s.is_dict(noneable=True),
            n_fields_limit=s.is_int(noneable=True),
            save_full_signal_radmat_npy=s.is_bool(),
            calibration=s.is_dict(),
            instrument_subject_id=s.is_str(noneable=True),
            # use_cycle_zero_psfs_only=s.is_bool(),
        )
    )

    def validate(self):
        """
        Apply defaults, check the schema, wrap calibration in a Calibration
        (filtered to instrument_subject_id when set), then check the
        radiometry_channels entries.

        super().validate() is deliberately bypassed so defaults can be applied
        with override_nones=False.

        Raises ValueError if subject-id filtering leaves no calibration records.
        """
        self.schema.apply_defaults(self.defaults, apply_to=self, override_nones=False)
        self.schema.validate(self, context=self.__class__.__name__)

        self.calibration = Calibration(self.calibration)
        subject_id = self.instrument_subject_id
        if subject_id is not None:
            self.calibration.filter_subject_ids(subject_id)
            if len(self.calibration.keys()) == 0:
                raise ValueError(
                    f"All calibration records removed after filter_subject_ids on subject_id '{self.instrument_subject_id}'"
                )

        assert not self.calibration.has_subject_ids()

        mapping = self.radiometry_channels
        if mapping is None:
            return

        name_re = re.compile(r"[0-9a-z_]+")
        for ch_name, ch_index in mapping.items():
            self._validate(
                name_re.fullmatch(ch_name),
                "radiometry_channels name must be lower-case alphanumeric (including underscore)",
            )
            self._validate(isinstance(ch_index, int), "channel_i must be an integer")

    def set_radiometry_channels_from_input_channels_if_needed(self, n_channels):
        """When radiometry_channels is unset, map "ch_<i>" -> i for every input channel."""
        if self.radiometry_channels is not None:
            return
        # Assume channels from nd2 manifest
        self.radiometry_channels = {f"ch_{ch_i}": ch_i for ch_i in range(n_channels)}

    @property
    def n_output_channels(self):
        return len(self.radiometry_channels.keys())

    @property
    def n_input_channels(self):
        return len(self.radiometry_channels.keys())

    # @property
    # def channels_cycles_dim(self):
    #     # This is a cache set in sigproc_v1.
    #     # It is a helper for the repetitive call:
    #     # n_outchannels, n_inchannels, n_cycles, dim =
    #     return self._outchannels_inchannels_cycles_dim

    def _input_channels(self):
        """
        Map each output channel index to its input channel index.

        Output channels are ordered by radiometry channel name. Example: with
        radiometry_channels == {"bar": 1}, output channel 0 ("bar") reads
        input channel 1, so this returns [1].
        """
        by_name = sorted(self.radiometry_channels.items(), key=lambda item: item[0])
        return [in_ch for _, in_ch in by_name]

    # def input_names(self):
    #     return sorted(self.radiometry_channels.keys())

    def output_channel_to_input_channel(self, out_ch):
        """Input channel index feeding output channel out_ch."""
        mapping = self._input_channels()
        return mapping[out_ch]

    def input_channel_to_output_channel(self, in_ch):
        """Not every input channel necessarily has an output; can return None"""
        return utils.filt_first_arg(self._input_channels(), lambda x: x == in_ch)
class SimV2Params(ParamsAndPriors):
    """
    Parameters for the v2 simulator: dyes, labels, channel map, cycle counts
    and sampling settings, with error-model terms resolved from self.priors.
    """

    # The following constants are repeated in sim_v2.h because it
    # is hard to get constants like this to be shared between
    # the two languages. This shouldn't be a problem as they are stable.
    # TODO: Move these to an import form the pyx
    CycleKindType = np.uint8
    CYCLE_TYPE_PRE = 0
    CYCLE_TYPE_MOCK = 1
    CYCLE_TYPE_EDMAN = 2

    channel__priors__columns = (
        "ch_i",
        "channel_name",
        "bg_mu",
        "bg_sigma",
        "dye_name",
        "gain_mu",
        "gain_sigma",
        "index",
        "p_bleach",
        "row_k_sigma",
    )

    dye__label__priors__columns = (
        "channel_name",
        "dye_name",
        "aa",
        "label_name",
        "ptm_only",
        "p_non_fluorescent",
        "ch_i",
        "bg_mu",
        "bg_sigma",
        "gain_mu",
        "gain_sigma",
        "row_k_sigma",
        "p_bleach",
    )

    defaults = Munch(
        n_pres=1,
        n_mocks=0,
        n_edmans=1,
        n_samples_train=5_000,
        n_samples_test=1_000,
        dyes=[],
        labels=[],
        random_seed=None,
        allow_train_test_to_be_identical=False,
        allow_edman_cterm=False,
        enable_ptm_labels=False,
        is_survey=False,
        train_includes_radmat=False,
        test_includes_dyemat=False,
        dump_debug=False,
        generate_flus=True,
        use_lognormal_model=False,
    )

    schema = s(
        s.is_kws_r(
            priors_desc=Priors.priors_desc_schema,
            is_survey=s.is_bool(),
            n_pres=s.is_int(bounds=(0, None)),
            n_mocks=s.is_int(bounds=(0, None)),
            n_edmans=s.is_int(bounds=(0, None)),
            n_samples_train=s.is_int(bounds=(1, None)),
            n_samples_test=s.is_int(bounds=(1, None)),
            dyes=s.is_list(
                elems=s.is_kws_r(
                    dye_name=s.is_str(),
                    channel_name=s.is_str(),
                )
            ),
            labels=s.is_list(
                elems=s.is_kws_r(
                    aa=s.is_str(),
                    dye_name=s.is_str(),
                    label_name=s.is_str(),
                    ptm_only=s.is_bool(required=False, noneable=True),
                )
            ),
            channels=s.is_dict(required=False),
            random_seed=s.is_int(required=False, noneable=True),
            allow_train_test_to_be_identical=s.is_bool(required=False, noneable=True),
            allow_edman_cterm=s.is_bool(required=False, noneable=True),
            enable_ptm_labels=s.is_bool(required=False, noneable=True),
            train_includes_radmat=s.is_bool(required=False, noneable=True),
            test_includes_dyemat=s.is_bool(required=False, noneable=True),
            dump_debug=s.is_bool(),
            generate_flus=s.is_bool(),
            use_lognormal_model=s.is_bool(),
        )
    )

    # def copy(self):
    #     dst = utils.munch_deep_copy(self, klass_set={SimV2Params})
    #     assert isinstance(dst, SimV2Params)
    #     return dst

    def __init__(self, **kwargs):
        # NOTE(review): an older comment said "_skip_setup_dfs is True in fixture
        # mode", but no such flag is consulted here; _setup_dfs always runs.
        super().__init__(source="SimV2Params", **kwargs)
        self._setup_dfs()

    def validate(self):
        """
        Schema-validate, then check dye/label consistency and the channel map.
        If "channels" is absent it is assigned by sorted channel name.
        """
        super().validate()

        all_dye_names = {d.dye_name for d in self.dyes}

        # No duplicate dye names
        self._validate(
            len(all_dye_names) == len(self.dyes), "The dye list contains a duplicate"
        )

        # No duplicate labels (one label per aa)
        self._validate(
            len(set(utils.listi(self.labels, "aa"))) == len(self.labels),
            "There is a duplicate label in the label_set",
        )

        # All labels have a legit dye name. A plain loop: the old list
        # comprehension was evaluated only for its side effects.
        for label in self.labels:
            self._validate(
                label.dye_name in all_dye_names,
                f"Label {label.label_name} does not have a valid matching dye_name",
            )

        # Channel mappings
        mentioned_channels = {dye.channel_name: False for dye in self.dyes}
        if "channels" in self:
            # Validate that channel mapping is complete
            for channel_name in self.channels.keys():
                self._validate(
                    channel_name in mentioned_channels,
                    f"Channel name '{channel_name}' was not found in dyes",
                )
                mentioned_channels[channel_name] = True

            self._validate(
                all(mentioned_channels.values()),
                "Not all channels in dyes were enumerated in channels",
            )
        else:
            # No channel mapping: assign them in sorted-name order
            self["channels"] = {
                ch_name: i
                for i, ch_name in enumerate(sorted(mentioned_channels.keys()))
            }

    @property
    def n_cycles(self):
        return self.n_pres + self.n_mocks + self.n_edmans

    def channel_names(self):
        """Channel names ordered by their channel index in self.channels."""
        return [
            ch_name
            for ch_name, _ in sorted(self.channels.items(), key=lambda item: item[1])
        ]

    def ch_i_by_name(self):
        return self.channels

    @property
    def n_channels(self):
        # if self.is_photobleaching_run:
        #     return 1
        return len(self.channels)

    @property
    def n_channels_and_cycles(self):
        return self.n_channels, self.n_cycles

    def _setup_dfs(self):
        """
        Assemble all of the priors into several dataframes indexed differently.
        (Call after validate)

        * self.channel__priors:
            ch_i,
            ch_name,
            bg_mu,
            bg_sigma,
            gain_mu,
            gain_sigma,
            row_k_sigma,
            p_bleach
            --> Note, does NOT have p_non_fluorescent because this is a dye property

        * self.dye__label__priors:
            aa,
            label_name,
            dye_name,
            ch_i,
            ch_name,
            bg_mu,
            bg_sigma,
            gain_mu,
            gain_sigma,
            row_k_sigma,
            p_bleach
            p_non_fluorescent,
        """

        # if self.is_photobleaching_run:
        #     # Not sure what these should be yet
        #     # self._ch_by_aa = {}
        #     # self._channel__priors = pd.DataFrame(columns=self.channel__priors__columns)
        #     # self._dye__label__priors = pd.DataFrame(columns=self.dye__label__priors__columns)
        #     self.dyes = [Munch(dye_name="zero", channel_name="zero")]
        #     self.channels = Munch(zero=0)
        #     self.labels = [
        #         dict(aa=".", dye_name="zero", label_name="zero", ptm_only=False)
        #     ]

        labels_df = pd.DataFrame(self.labels)
        # labels_df: (aa, dye_name, label_name, ptm_only)
        # assert len(labels_df) > 0

        dyes_df = pd.DataFrame(self.dyes)
        # dyes_df: (dye_name, channel_name)
        # assert len(dyes_df) > 0

        # LOOKUP dye priors
        dye_priors = []
        for dye in self.dyes:
            # SEARCH priors by dye name and if not found by channel
            # NOTE(review): the fallback key embeds channel_name ("ch_<name>")
            # while the channel priors below use "ch_<ch_i>"; with names like
            # "ch_0" this becomes "ch_ch_0" — confirm this is intended.
            p_non_fluorescent = self.priors.get_exact(
                f"p_non_fluorescent.{dye.dye_name}"
            )
            if p_non_fluorescent is None:
                p_non_fluorescent = self.priors.get(
                    f"p_non_fluorescent.ch_{dye.channel_name}"
                )

            dye_priors += [
                Munch(
                    dye_name=dye.dye_name,
                    p_non_fluorescent=p_non_fluorescent.prior,
                )
            ]

        dye_priors_df = pd.DataFrame(dye_priors)
        # dye_priors_df: (dye_name, p_non_fluorescent)

        dyes_df = utils.easy_join(dyes_df, dye_priors_df, "dye_name")
        # dyes_df: (dye_name, channel_name, p_non_fluorescent)

        # TODO: LOOKUP label priors
        #   (p_failure_to_bind_aa, p_failure_to_attach_to_dye)

        # LOOKUP channel priors
        ch_priors = pd.DataFrame(
            [
                dict(
                    channel_name=channel_name,
                    ch_i=ch_i,
                    bg_mu=self.priors.get(f"bg_mu.ch_{ch_i}").prior,
                    bg_sigma=self.priors.get(f"bg_sigma.ch_{ch_i}").prior,
                    gain_mu=self.priors.get(f"gain_mu.ch_{ch_i}").prior,
                    gain_sigma=self.priors.get(f"gain_sigma.ch_{ch_i}").prior,
                    row_k_sigma=self.priors.get(f"row_k_sigma.ch_{ch_i}").prior,
                    p_bleach=self.priors.get(f"p_bleach.ch_{ch_i}").prior,
                )
                for channel_name, ch_i in self.channels.items()
            ]
        )
        # ch_priors: (channel_name, ch_i, ...)

        self._channel__priors = (
            utils.easy_join(dyes_df, ch_priors, "channel_name")
            .drop(columns=["p_non_fluorescent"])
            .drop_duplicates()
            .reset_index()
        )
        # self._channel__priors: (
        #    'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
        #    'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        # )

        # SANITY check channel__priors
        group_by_ch = self._channel__priors.groupby("ch_i")
        for field in (
            "bg_mu",
            "bg_sigma",
            "gain_mu",
            "gain_sigma",
            "row_k_sigma",
        ):
            assert np.all(group_by_ch[field].nunique() == 1)
        assert "p_non_fluorescent" not in self._channel__priors.columns

        labels_dyes_df = utils.easy_join(labels_df, dyes_df, "dye_name")
        self._dye__label__priors = utils.easy_join(
            labels_dyes_df, ch_priors, "channel_name"
        ).reset_index(drop=True)
        # self._dye__label__priors: (
        #   'channel_name', 'dye_name', 'aa', 'label_name',
        #   'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
        #   'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        # )

        self._ch_by_aa = {
            row.aa: row.ch_i for row in self._dye__label__priors.itertuples()
        }

    def ch_by_aa(self):
        return self._ch_by_aa

    def dye__label__priors(self):
        """
        DataFrame(
            'channel_name', 'dye_name', 'aa', 'label_name',
            'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
            'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        )
        """
        return self._dye__label__priors

    def channel__priors(self):
        """
        DataFrame(
            'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
            'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        )
        """
        return self._channel__priors

    def by_channel(self):
        return self._channel__priors.set_index("ch_i")

    def to_label_list(self):
        """Summarize labels like: ["DE", "C"]"""
        return [
            "".join(
                [label.aa for label in self.labels if label.dye_name == dye.dye_name]
            )
            for dye in self.dyes
        ]

    def to_label_str(self):
        """Summarize labels like: DE,C"""
        return ",".join(self.to_label_list())

    def cycles_array(self):
        """
        Return a (n_cycles,) array of CycleKindType: n_pres CYCLE_TYPE_PRE,
        then n_mocks CYCLE_TYPE_MOCK, then n_edmans CYCLE_TYPE_EDMAN.
        """
        kinds = np.array(
            [self.CYCLE_TYPE_PRE, self.CYCLE_TYPE_MOCK, self.CYCLE_TYPE_EDMAN],
            dtype=self.CycleKindType,
        )
        return np.repeat(kinds, [self.n_pres, self.n_mocks, self.n_edmans])

    def pcbs(self, pep_seq_df):
        """
        pcb stands for (p)ep_i, (c)hannel_i, (b)right_probability

        This is a structure that is like a "flu" but with an extra bright probability.

        Each peptide has a row for each amino acid.
            That row has the columns (pep_i, ch_i, p_bright)
            And it will have np.nan for ch_i and p_bright **IF THERE IS NO LABEL**

        bright_probability is the inverse of all the ways a dye can fail to be visible
        ie the probability that a dye is active.

        pep_seq_df: DataFrame with "aa", "pep_i" and "pep_offset_in_pro" columns

        Returns:
            contiguous ndarray(:, 3) where the 3 columns are:
                pep_i, ch_i, p_bright
            rows sorted by (pep_i, pep_offset_in_pro)
        """
        labelled_pep_df = pep_seq_df.join(
            self.dye__label__priors().set_index("aa"), on="aa", how="left"
        )

        # p_bright = is the product of (1.0 - ) all the ways the dye can fail to be visible.
        labelled_pep_df["p_bright"] = (
            # TODO: Sim needs to be converted to use priors sampling
            #   at which point this function needs to be refactored
            #   so that the parameters of the priors can be sampled in C.
            1.0
            - np.array(
                [
                    i.sample() if isinstance(i, Prior) else np.nan
                    for i in labelled_pep_df.p_non_fluorescent
                ]
            )
            # TODO: Add label priors
            # * (1.0 - labelled_pep_df.p_failure_to_attach_to_dye)
            # * (1.0 - labelled_pep_df.p_failure_to_bind_aa)
        )

        labelled_pep_df.sort_values(by=["pep_i", "pep_offset_in_pro"], inplace=True)
        return np.ascontiguousarray(labelled_pep_df[["pep_i", "ch_i", "p_bright"]].values)

    @classmethod
    def from_aa_list_fixture(cls, aa_list, priors=None, **kwargs):
        """
        This is a helper to generate channel when you have a list of aas.
        For example, two channels where ch0 is D&E and ch1 is Y.
        ["DE", "Y"].
        """
        check.list_or_tuple_t(aa_list, str)

        allowed_aa_mods = ["[", "]"]
        assert all(
            (aa.isalpha() or aa in allowed_aa_mods) for aas in aa_list for aa in aas
        )

        dyes = [
            Munch(dye_name=f"dye_{ch}", channel_name=f"ch_{ch}")
            for ch, _ in enumerate(aa_list)
        ]

        # Note the extra for loop because "DE" needs to be split into "D" & "E"
        # which is done by aa_str_to_list() - which also handles PTMs like S[p]
        labels = [
            Munch(
                aa=aa,
                dye_name=f"dye_{ch}",
                label_name=f"label_{ch}",
                ptm_only=False,
            )
            for ch, aas in enumerate(aa_list)
            for aa in aa_str_to_list(aas)
        ]

        return cls(dyes=dyes, labels=labels, priors=priors, **kwargs)