Example #1
def it_returns_required_elems():
    userdata = dict(some_key=1)

    test_s = s(
        s.is_dict(
            all_required=True,
            elems=dict(
                a=s.is_int(),
                b=s.is_float(help="A float"),
                c=s.is_number(),
                d=s.is_str(userdata=userdata),
                e=s.is_list(),
                f=s.is_dict(all_required=True,
                            elems=dict(d=s.is_int(), e=s.is_int())),
            ),
        ))
    reqs = test_s.requirements()
    assert reqs == [
        ("a", int, None, None),
        ("b", float, "A float", None),
        ("c", float, None, None),
        ("d", str, None, userdata),
        ("e", list, None, None),
        ("f", dict, None, None),
    ]
Example #2
def _before():
    # test_s is bound in the enclosing zest test scope
    nonlocal test_s
    test_s = s(
        s.is_dict(
            all_required=True,
            elems=dict(
                a=s.is_int(),
                b=s.is_int(),
                c=s.is_dict(all_required=True,
                            elems=dict(d=s.is_int(), e=s.is_int())),
            ),
        ))
Example #3
def it_validates_recursively():
    test_s = s(
        s.is_dict(elems=dict(
            a=s.is_int(),
            b=s.is_list(required=True, elems=s.is_str()),
            c=s.is_dict(required=True),
        )))
    test_s.validate(dict(a=1, b=["a", "b"], c=dict()))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(dict(a=1, b=[1], c=dict()))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(dict(a=1, b=["a"], c=1))
Example #4
def it_ignores_underscored_keys():
    test_s = s(
        s.is_dict(elems=dict(a=s.is_int(), b=s.is_int()),
                  no_extras=True))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(dict(a=1, b="str"))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(dict(a=1, b=2, _c=[]))

    test_s = s(
        s.is_dict(
            elems=dict(a=s.is_int(), b=s.is_int()),
            no_extras=True,
            ignore_underscored_keys=True,
        ))
    test_s.validate(dict(a=1, b=2, _c=[]))
Example #5
def it_shows_help():
    schema = s(
        s.is_kws(a=s.is_dict(
            help="Help for a",
            elems=dict(
                b=s.is_int(help="Help for b"),
                c=s.is_kws(d=s.is_int(help="Help for d")),
            ),
        )))
    schema.help()

    help_calls = m_print_help.normalized_calls()
    help_calls = [{h["key"]: h["help"]} for h in help_calls]
    assert help_calls == [
        {"root": None},
        {"a": "Help for a"},
        {"b": "Help for b"},
        {"c": None},
        {"d": "Help for d"},
    ]
Example #6
def it_fetches_user_data():
    schema = s(
        s.is_dict(
            help="Help for a",
            elems=dict(
                b=s.is_int(help="Help for b", userdata="userdata_1"),
                c=s.is_kws(d=s.is_int(help="Help for d")),
            ),
        ))
    tlf = schema.top_level_fields()
    assert tlf[0][0] == "b" and tlf[0][3] == "userdata_1"
Example #7
class Priors:
    """
    See above docs
    """

    priors_desc_schema = s.is_dict(
        elems=dict(
            class_name=s.is_str(),
            hyper_params=s.is_dict(),
            params=s.is_dict(required=False),
        ),
        required=False,
    )

    def add(self, name, prior, source=None):
        check.t(prior, Prior)
        if source is not None:
            prior.source = source
        assert prior.source is not None
        self.priors[name] = prior

    def delete_ch_specific_records(self):
        remove_keys = []
        for key in self.priors.keys():
            parts = key.split(".")
            if len(parts) > 1:
                if parts[1].startswith("ch_"):
                    remove_keys += [key]
        self.priors = {
            key: val for key, val in self.priors.items() if key not in remove_keys
        }

    def _instanciate_prior(
        self, source, name, class_name, hyper_params, params, overwrite=False
    ):
        """
        ADD a prior instance of class_name from source into name
        """
        parts = class_name.split(".")
        if len(parts) == 1:
            # Use this module
            module = sys.modules[__name__]
        else:
            module = importlib.import_module(".".join(parts[0:-1]))

        klass = getattr(module, parts[-1])
        instance = klass(**(hyper_params or {}))
        if isinstance(instance, PriorsIncludeFile):
            # RECURSE for include file
            self._instanciate_priors_desc(
                f"Include file '{instance.path}'", instance.priors_desc
            )
        else:
            instance.source = source
            instance.deserialize(params)
            if name in self.priors and not overwrite:
                raise ValueError(f"Duplicate prior name '{name}'")
            self.add(name, instance)

    def _instanciate_priors_desc(self, source, priors_desc):
        for name, desc in priors_desc.items():
            self._instanciate_prior(
                source,
                name,
                desc["class_name"],
                desc.get("hyper_params"),
                desc.get("params"),
            )

    @classmethod
    def from_priors_desc(cls, source, priors_desc):
        """
        Create Priors from a desc, for example from a task parameters block.
        These can include "PriorsIncludeFile" entries, which recursively add more priors.
        """
        priors = Priors()
        priors._instanciate_priors_desc(source, priors_desc)
        return priors

    def __init__(self, hyper_im_mea=None, hyper_n_channels=None):
        self._default_reg_illum = None
        self._default_reg_psf = None
        self._default_ch_aln = None

        if hyper_im_mea is not None:
            self._default_reg_illum = RegIllumPrior(hyper_im_mea).set(
                hyper_im_mea / 2, hyper_im_mea / 2, 0.0
            )

            hyper_n_divs = 5
            sigma_x = np.full((hyper_n_divs, hyper_n_divs), 1.8)
            sigma_y = np.full((hyper_n_divs, hyper_n_divs), 1.8)
            rho = np.full((hyper_n_divs, hyper_n_divs), 0.0)
            self._default_reg_psf = RegPSFPrior(
                hyper_im_mea,
                hyper_peak_mea=default_hyper_peak_mea,
                hyper_n_divs=hyper_n_divs,
            ).set(sigma_x=sigma_x, sigma_y=sigma_y, rho=rho)

        if hyper_n_channels is not None:
            self._default_ch_aln = ChannelAlignPrior(hyper_n_channels).set(
                np.zeros((hyper_n_channels, 2))
            )

        self._defaults = {
            "reg_illum": self._default_reg_illum,
            "reg_psf": self._default_reg_psf,
            "ch_aln": self._default_ch_aln,
            "gain_mu": MLEPrior().set(value=7500.0),
            "gain_sigma": MLEPrior().set(value=0.0),
            "bg_mu": MLEPrior().set(value=0.0),
            "bg_sigma": MLEPrior().set(value=100.0),
            "row_k_sigma": MLEPrior().set(value=0.15),
            "p_edman_failure": MLEPrior().set(value=0.06),
            "p_detach": MLEPrior().set(value=0.05),
            "p_bleach": MLEPrior().set(value=0.05),
            "p_non_fluorescent": MLEPrior().set(value=0.07),
        }

        self.priors = {}

    def get_exact(self, request_name):
        """
        Look up exact match or return None
        """
        if request_name in self.priors:
            return Munch(
                request_name=request_name,
                matched_name=request_name,
                prior=self.priors[request_name],
            )

        return None

    def get(self, request_name):
        """
        Search for a request_name looking up the naming tree for a match if needed
        """
        parts = request_name.split(".")
        for i in range(len(parts), 0, -1):
            # SEARCH up the hierarchy
            matched_name = ".".join(parts[0:i])
            if matched_name in self.priors:
                return Munch(
                    request_name=request_name,
                    matched_name=matched_name,
                    prior=self.priors[matched_name],
                )

        # Not found, look in defaults
        matched_name = parts[0]
        if matched_name in self._defaults:
            default = self._defaults[parts[0]]
            if default is not None:
                default.source = "defaults"
                return Munch(
                    request_name=request_name, matched_name=matched_name, prior=default,
                )

        raise KeyError(f"Prior '{request_name}' not resolved")

    def enumerate_names(self):
        return list(self.priors.keys())

    def remove(self, name):
        del self.priors[name]

    def update(self, other, source):
        for key, value in other.priors.items():
            value = copy.deepcopy(value)
            value.source = source
            self.priors[key] = value

    def get_distr(self, request_name):
        """
        Like get() but returns the distribution object only (without the other metadata)
        """
        p = self.get(request_name)
        return p.prior

    def get_sample(self, request_name):
        """
        Like get() but returns a sample from the distribution
        """
        p = self.get(request_name)
        return p.prior.sample()

    def get_mle(self, request_name, **kwargs):
        """
        Like get() but returns the MLE.
        For now, this means that the prior must come from an MLEPrior
        """
        p = self.get(request_name)
        assert isinstance(p.prior, MLEPrior)
        return p.prior.sample()

    def helper_illum_model(self, n_channels):
        """
        Used by nn_v2 to load the C context with cols: gain_mu, gain_sigma, bg_mu, bg_sigma
        One row per channel.

        Note that this is still providing the older functionality of returning
        an MLE value for the parameters. Eventually this is going to be changed
        so that the nn_v2 will have a parametric description of the prior
        so that it can draw samples instead of using the MLE.
        """

        illum_model = np.zeros((n_channels, 4))
        for ch_i in range(n_channels):
            prior = self.get(f"gain_mu.ch_{ch_i}")
            assert isinstance(prior.prior, MLEPrior)
            gain_mu = prior.prior.sample()
            illum_model[ch_i, 0] = gain_mu

            prior = self.get(f"gain_sigma.ch_{ch_i}")
            assert isinstance(prior.prior, MLEPrior)
            gain_sigma = prior.prior.sample()
            illum_model[ch_i, 1] = gain_sigma

            illum_model[ch_i, 2] = 0.0

            prior = self.get(f"bg_sigma.ch_{ch_i}")
            assert isinstance(prior.prior, MLEPrior)
            bg_sigma = prior.prior.sample()
            illum_model[ch_i, 3] = bg_sigma

        return illum_model

    @classmethod
    def copy(cls, src):
        self = cls()
        self.priors = copy.deepcopy(src.priors)
        return self
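
The naming-tree search in Priors.get() above tries the dotted request name from most to least specific before falling back to the _defaults table. Below is a minimal usage sketch (not from the original module) of that resolution order, assuming MLEPrior is importable alongside Priors and that .set() returns the prior instance, as it does in _defaults above.

priors = Priors()
priors.add("gain_mu", MLEPrior().set(value=7000.0), source="example")
priors.add("gain_mu.ch_1", MLEPrior().set(value=7200.0), source="example")

# An exact channel-specific entry wins over the generic one.
assert priors.get("gain_mu.ch_1").matched_name == "gain_mu.ch_1"
# There is no "gain_mu.ch_0" entry, so the search climbs the tree to "gain_mu".
assert priors.get("gain_mu.ch_0").matched_name == "gain_mu"
# Unknown names fall through to the _defaults table (prior.source becomes "defaults").
assert priors.get("p_detach.ch_0").matched_name == "p_detach"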
Example #8
def it_checks_no_extra():
    test_s = s(
        s.is_dict(elems=dict(a=s.is_int(), b=s.is_int()),
                  no_extras=True))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(dict(a=1, b=1, c=1))
Example #9
def it_allows_elems_in_dict():
    s(s.is_dict(dict(a=s.is_int(noneable=True))))
Example #10
def it_does_not_overwrite_an_existing_dict():
    test_s = s(s.is_kws_r(a=s.is_dict()))
    test = dict(a=dict(b=1))
    # a has a good value, do not overwrite it
    test_s.apply_defaults(defaults=dict(a=None), apply_to=test)
    assert test["a"]["b"] == 1
Example #11
def it_all_required_false_by_default():
    test_s = s(s.is_dict(elems=dict(a=s.is_int(), b=s.is_int())))
    test_s.validate(dict(a=1))
Example #12
def it_validates_default_dict():
    test_s = s(s.is_dict())
    test_s.validate({})
    test_s.validate(dict(a=1, b=2))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(1)
Example #13
class SimV1Params(ParamsAndPriors):
    """
    Simulation parameters: an ErrorModel plus parameters for sim
    """

    defaults = Munch(
        n_pres=1,
        n_mocks=0,
        n_edmans=1,
        n_samples_train=5_000,
        n_samples_test=1_000,
        dyes=[],
        labels=[],
        random_seed=None,
        train_n_sample_multiplier=None,  # This does not appear to be used anywhere. tfb
        allow_train_test_to_be_identical=False,
        enable_ptm_labels=False,
        is_survey=False,
    )

    schema = s(
        s.is_kws_r(
            is_survey=s.is_bool(),
            priors_desc=Priors.priors_desc_schema,
            n_pres=s.is_int(bounds=(0, None)),
            n_mocks=s.is_int(bounds=(0, None)),
            n_edmans=s.is_int(bounds=(0, None)),
            n_samples_train=s.is_int(bounds=(1, None)),
            n_samples_test=s.is_int(bounds=(1, None)),
            dyes=s.is_list(elems=s.is_kws_r(dye_name=s.is_str(),
                                            channel_name=s.is_str())),
            labels=s.is_list(elems=s.is_kws_r(
                aa=s.is_str(),
                dye_name=s.is_str(),
                label_name=s.is_str(),
                ptm_only=s.is_bool(required=False, noneable=True),
            )),
            channels=s.is_dict(required=False),
            random_seed=s.is_int(required=False, noneable=True),
            allow_train_test_to_be_identical=s.is_bool(required=False,
                                                       noneable=True),
            enable_ptm_labels=s.is_bool(required=False, noneable=True),
        ))

    # def copy(self):
    #     # REMOVE everything that _build_join_dfs put in
    #     utils.safe_del(self, "df")
    #     utils.safe_del(self, "by_channel")
    #     utils.safe_del(self, "ch_by_aa")
    #
    #     dst = utils.munch_deep_copy(self, klass_set={SimV1Params})
    #     dst.error_model = ErrorModel(**dst.error_model)
    #     assert isinstance(dst, SimV1Params)
    #     return dst

    def __init__(self, **kwargs):
        super().__init__(source="SimV1Params", **kwargs)
        self._setup_dfs()

    def validate(self):
        super().validate()

        all_dye_names = list(set([d.dye_name for d in self.dyes]))

        # No duplicate dye names
        self._validate(
            len(all_dye_names) == len(self.dyes),
            "The dye list contains a duplicate")

        # No duplicate labels
        self._validate(
            len(list(set(utils.listi(self.labels, "aa")))) == len(self.labels),
            "There is a duplicate label",
        )

        # All labels have a legit dye name
        [
            self._validate(
                label.dye_name in all_dye_names,
                f"Label {label.label_name} does not have a valid matching dye_name",
            ) for label in self.labels
        ]

        # Channel mappings
        mentioned_channels = {dye.channel_name: False for dye in self.dyes}
        if "channels" in self:
            # Validate that channel mapping is complete
            for channel_name, ch_i in self.channels.items():
                self._validate(
                    channel_name in mentioned_channels,
                    f"Channel name '{channel_name}' was not found in dyes",
                )
                mentioned_channels[channel_name] = True

            self._validate(
                all([mentioned
                     for _, mentioned in mentioned_channels.items()]),
                "Not all channels in dyes were enumerated in channels",
            )
        else:
            # No channel mapping: assign them
            self["channels"] = {
                ch_name: i
                for i, ch_name in enumerate(sorted(mentioned_channels.keys()))
            }

    @property
    def n_cycles(self):
        return self.n_pres + self.n_mocks + self.n_edmans

    def channel_names(self):
        return sorted(list(set(utils.listi(self.dyes, "channel_name"))))

    def channel_i_by_name(self):
        channels = self.channel_names()
        return {
            channel_name: channel_i
            for channel_i, channel_name in enumerate(channels)
        }

    @property
    def n_channels(self):
        return len(self.channel_i_by_name().keys())

    @property
    def n_channels_and_cycles(self):
        return self.n_channels, self.n_cycles

    def _setup_dfs(self):
        """
        The error model contains information about the dyes and labels and other terms.
        Those error model parameters are wired together by names which are useful
        for reconciling calibrations.

        But here, these "by name" parameters are all put into a dataframe so that
        they can be indexed by integers.
        """
        dyes_df = pd.DataFrame(self.dyes)
        assert len(dyes_df) > 0

        labels_df = pd.DataFrame(self.labels)
        assert len(labels_df) > 0

        # LOOKUP dye priors
        dye_priors = []
        for dye in self.dyes:
            # SEARCH priors by dye name and if not found by channel
            p_non_fluorescent = self.priors.get_exact(
                f"p_non_fluorescent.{dye.dye_name}")
            if p_non_fluorescent is None:
                p_non_fluorescent = self.priors.get(
                    f"p_non_fluorescent.ch_{dye.channel_name}")

            dye_priors += [
                Munch(
                    dye_name=dye.dye_name,
                    p_non_fluorescent=p_non_fluorescent.prior,
                )
            ]

        dye_priors_df = pd.DataFrame(dye_priors)
        # dye_priors_df: (dye_name, p_non_fluorescent)

        dyes_df = utils.easy_join(dyes_df, dye_priors_df, "dye_name")
        # dyes_df: (dye_name, channel_name, p_non_fluorescent)

        # TODO: LOOKUP label priors
        #       (p_failure_to_bind_aa, p_failure_to_attach_to_dye)

        # LOOKUP channel priors
        ch_priors = pd.DataFrame([
            dict(
                channel_name=channel_name,
                ch_i=ch_i,
                bg_mu=self.priors.get(f"bg_mu.ch_{ch_i}").prior,
                bg_sigma=self.priors.get(f"bg_sigma.ch_{ch_i}").prior,
                gain_mu=self.priors.get(f"gain_mu.ch_{ch_i}").prior,
                gain_sigma=self.priors.get(f"gain_sigma.ch_{ch_i}").prior,
                row_k_sigma=self.priors.get(f"row_k_sigma.ch_{ch_i}").prior,
                p_bleach=self.priors.get(f"p_bleach.ch_{ch_i}").prior,
            ) for channel_name, ch_i in self.channels.items()
        ])
        # ch_priors: (channel_name, ch_i, ...)

        self._channel__priors = (utils.easy_join(
            dyes_df, ch_priors, "channel_name").drop(
                columns=["p_non_fluorescent"]).drop_duplicates().reset_index())
        # self._channel__priors: (
        #    'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
        #    'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        # )

        # SANITY check channel__priors
        group_by_ch = self._channel__priors.groupby("ch_i")
        for field in (
                "bg_mu",
                "bg_sigma",
                "gain_mu",
                "gain_sigma",
                "row_k_sigma",
        ):
            assert np.all(group_by_ch[field].nunique() == 1)
        assert "p_non_fluorescent" not in self._channel__priors.columns

        labels_dyes_df = utils.easy_join(labels_df, dyes_df, "dye_name")
        self._dye__label__priors = utils.easy_join(
            labels_dyes_df, ch_priors, "channel_name").reset_index(drop=True)

        # self._dye__label__priors: (
        #     'channel_name', 'dye_name', 'aa', 'label_name',
        #     'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
        #     'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        # )

        self._ch_by_aa = {
            row.aa: row.ch_i
            for row in self._dye__label__priors.itertuples()
        }

    def dye__label__priors(self):
        """
        DataFrame(
            'channel_name', 'dye_name', 'aa', 'label_name',
            'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
            'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        )
        """
        return self._dye__label__priors

    def channel__priors(self):
        """
        DataFrame(
            'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
            'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        )
        """
        return self._channel__priors

    def by_channel(self):
        return self._channel__priors.set_index("ch_i")

    def to_label_list(self):
        """Summarize labels like: ["DE", "C"]"""
        return [
            "".join([
                label.aa for label in self.labels
                if label.dye_name == dye.dye_name
            ]) for dye in self.dyes
        ]

    def to_label_str(self):
        """Summarize labels like: DE,C"""
        return ",".join(self.to_label_list())

    @classmethod
    def construct_from_aa_list(cls, aa_list, **kwargs):
        """
        This is a helper to generate channels when you have a list of aas.
        For example, two channels where ch0 is D&E and ch1 is Y:
        ["DE", "Y"].

        If you pass in an error model, it needs to match channels and labels.
        """

        check.list_or_tuple_t(aa_list, str)

        allowed_aa_mods = ["[", "]"]
        assert all([(aa.isalpha() or aa in allowed_aa_mods) for aas in aa_list
                    for aa in list(aas)])

        dyes = [
            Munch(dye_name=f"dye_{ch}", channel_name=f"ch_{ch}")
            for ch, _ in enumerate(aa_list)
        ]

        # Note the extra for loop because "DE" needs to be split into "D" & "E"
        # which is done by aa_str_to_list() - which also handles PTMs like S[p]
        labels = [
            Munch(
                aa=aa,
                dye_name=f"dye_{ch}",
                label_name=f"label_{ch}",
                ptm_only=False,
            ) for ch, aas in enumerate(aa_list) for aa in aa_str_to_list(aas)
        ]

        return cls(dyes=dyes, labels=labels, **kwargs)
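
construct_from_aa_list() above fans a list like ["DE", "Y"] out into one dye per channel and one label per amino acid. The following is a minimal standalone sketch (not part of the class) of that expansion, substituting list(aas) for aa_str_to_list(), which in the real helper also handles PTM notation such as S[p].

from munch import Munch

aa_list = ["DE", "Y"]
dyes = [
    Munch(dye_name=f"dye_{ch}", channel_name=f"ch_{ch}")
    for ch, _ in enumerate(aa_list)
]
labels = [
    Munch(aa=aa, dye_name=f"dye_{ch}", label_name=f"label_{ch}", ptm_only=False)
    for ch, aas in enumerate(aa_list)
    for aa in list(aas)  # aa_str_to_list(aas) in the real helper
]
assert [d.channel_name for d in dyes] == ["ch_0", "ch_1"]
assert [(l.aa, l.dye_name) for l in labels] == [("D", "dye_0"), ("E", "dye_0"), ("Y", "dye_1")]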
Example #14
class SigprocV1Params(Params):
    defaults = dict(
        hat_rad=2,
        iqr_rng=96,
        threshold_abs=1.0,
        channel_indices_for_alignment=None,
        channel_indices_for_peak_finding=None,
        radiometry_channels=None,
        save_debug=False,
        peak_find_n_cycles=4,
        peak_find_start=0,
        radial_filter=None,
        anomaly_iqr_cutoff=95,
        n_fields_limit=None,
        save_full_signal_radmat_npy=False,
    )

    schema = s(
        s.is_kws_r(
            anomaly_iqr_cutoff=s.is_number(noneable=True, bounds=(0, 100)),
            radial_filter=s.is_float(noneable=True, bounds=(0, 1)),
            peak_find_n_cycles=s.is_int(bounds=(1, None), noneable=True),
            peak_find_start=s.is_int(bounds=(0, None), noneable=True),
            save_debug=s.is_bool(),
            hat_rad=s.is_int(bounds=(1, 3)),
            iqr_rng=s.is_number(noneable=True, bounds=(0, 100)),
            threshold_abs=s.is_number(
                bounds=(0, 100)),  # Not sure of a reasonable bound
            channel_indices_for_alignment=s.is_list(s.is_int(), noneable=True),
            channel_indices_for_peak_finding=s.is_list(s.is_int(),
                                                       noneable=True),
            radiometry_channels=s.is_dict(noneable=True),
            n_fields_limit=s.is_int(noneable=True),
            save_full_signal_radmat_npy=s.is_bool(),
        ))

    def validate(self):
        # Note: does not call super() because override_nones is set to False here
        self.schema.apply_defaults(self.defaults,
                                   apply_to=self,
                                   override_nones=False)
        self.schema.validate(self, context=self.__class__.__name__)

        if self.radiometry_channels is not None:
            pat = re.compile(r"[0-9a-z_]+")
            for name, channel_i in self.radiometry_channels.items():
                self._validate(
                    pat.fullmatch(name),
                    "radiometry_channels name must be lower-case alphanumeric (including underscore)",
                )
                self._validate(isinstance(channel_i, int),
                               "channel_i must be an integer")

    def set_radiometry_channels_from_input_channels_if_needed(
            self, n_channels):
        if self.radiometry_channels is None:
            # Assume channels from nd2 manifest
            channels = list(range(n_channels))
            self.radiometry_channels = {f"ch_{ch}": ch for ch in channels}

    @property
    def n_output_channels(self):
        return len(self.radiometry_channels.keys())

    @property
    def n_input_channels(self):
        return len(self.radiometry_channels.keys())

    @property
    def channels_cycles_dim(self):
        # This is a cache set in sigproc_v1.
        # It is a helper for the repetitive call:
        # n_outchannels, n_inchannels, n_cycles, dim =
        return self._outchannels_inchannels_cycles_dim

    def _input_channels(self):
        """
        Return a list that maps output channel numbers to input channel numbers.
        Example:
            input might have channels ["foo", "bar"]
            the radiometry_channels has: {"bar": 1}
            Thus this function returns [1] because the 0th output channel is mapped
            to input channel 1
        """
        return [
            self.radiometry_channels[name]
            for name in sorted(self.radiometry_channels.keys())
        ]

    # def input_names(self):
    #     return sorted(self.radiometry_channels.keys())

    def output_channel_to_input_channel(self, out_ch):
        return self._input_channels()[out_ch]

    def input_channel_to_output_channel(self, in_ch):
        """Not every input channel necessarily has an output; can return None"""
        return utils.filt_first_arg(self._input_channels(),
                                    lambda x: x == in_ch)
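
The _input_channels() mapping above can be read as: the dict values are input channel indices, and the sorted key order defines the output channel order. A minimal standalone sketch of the same logic, outside the class and with hypothetical channel names:

radiometry_channels = {"foo": 0, "bar": 1}  # name -> input channel index
input_channels = [
    radiometry_channels[name] for name in sorted(radiometry_channels.keys())
]
# Output channels follow sorted name order: "bar" first, then "foo"
assert input_channels == [1, 0]
# output_channel_to_input_channel(out_ch) is then just input_channels[out_ch]
assert input_channels[0] == 1  # output channel 0 reads from input channel 1 ("bar")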
Example #15
def it_fetches_list_elem_type():
    schema = s(s.is_dict(elems=dict(a=s.is_list(s.is_int()))))
    tlf = schema.top_level_fields()
    assert tlf[0][0] == "a" and tlf[0][4] is int
Example #16
def it_checks_no_extra_flase_by_default():
    test_s = s(s.is_dict(elems=dict(a=s.is_int(), b=s.is_int())))
    test_s.validate(dict(a=1, c=1))
Example #17
def it_allows_all_required_to_be_overriden():
    test_s = s(
        s.is_dict(elems=dict(a=s.is_int(required=False)),
                  all_required=True))
    test_s.validate(dict())
Example #18
def it_checks_key_type_str():
    test_s = s(s.is_dict(elems={1: s.is_int()}))
    with zest.raises(SchemaValidationFailed):
        test_s.validate({1: 2})
Example #19
def it_checks_required():
    test_s = s(
        s.is_dict(
            elems=dict(a=s.is_int(required=True), b=s.is_int())))
    with zest.raises(SchemaValidationFailed):
        test_s.validate(dict(b=1))
Example #20
class SigprocV2Params(Params):
    defaults = dict(
        radiometry_channels=None,
        n_fields_limit=None,
        save_full_signal_radmat_npy=False,
        # use_cycle_zero_psfs_only=False,
    )

    schema = s(
        s.is_kws_r(
            radiometry_channels=s.is_dict(noneable=True),
            n_fields_limit=s.is_int(noneable=True),
            save_full_signal_radmat_npy=s.is_bool(),
            calibration=s.is_dict(),
            instrument_subject_id=s.is_str(noneable=True),
            # use_cycle_zero_psfs_only=s.is_bool(),
        ))

    def validate(self):
        # Note: does not call super() because override_nones is set to False here
        self.schema.apply_defaults(self.defaults,
                                   apply_to=self,
                                   override_nones=False)
        self.schema.validate(self, context=self.__class__.__name__)

        self.calibration = Calibration(self.calibration)
        if self.instrument_subject_id is not None:
            self.calibration.filter_subject_ids(self.instrument_subject_id)
            if len(self.calibration.keys()) == 0:
                raise ValueError(
                    f"All calibration records removed after filter_subject_ids on subject_id '{self.instrument_subject_id}'"
                )

        assert not self.calibration.has_subject_ids()

        if self.radiometry_channels is not None:
            pat = re.compile(r"[0-9a-z_]+")
            for name, channel_i in self.radiometry_channels.items():
                self._validate(
                    pat.fullmatch(name),
                    "radiometry_channels name must be lower-case alphanumeric (including underscore)",
                )
                self._validate(isinstance(channel_i, int),
                               "channel_i must be an integer")

    def set_radiometry_channels_from_input_channels_if_needed(
            self, n_channels):
        if self.radiometry_channels is None:
            # Assume channels from nd2 manifest
            channels = list(range(n_channels))
            self.radiometry_channels = {f"ch_{ch}": ch for ch in channels}

    @property
    def n_output_channels(self):
        return len(self.radiometry_channels.keys())

    @property
    def n_input_channels(self):
        return len(self.radiometry_channels.keys())

    # @property
    # def channels_cycles_dim(self):
    #     # This is a cache set in sigproc_v1.
    #     # It is a helper for the repetitive call:
    #     # n_outchannels, n_inchannels, n_cycles, dim =
    #     return self._outchannels_inchannels_cycles_dim

    def _input_channels(self):
        """
        Return a list that maps output channel numbers to input channel numbers.
        Example:
            input might have channels ["foo", "bar"]
            the radiometry_channels has: {"bar": 1}
            Thus this function returns [1] because the 0th output channel is mapped
            to input channel 1
        """
        return [
            self.radiometry_channels[name]
            for name in sorted(self.radiometry_channels.keys())
        ]

    # def input_names(self):
    #     return sorted(self.radiometry_channels.keys())

    def output_channel_to_input_channel(self, out_ch):
        return self._input_channels()[out_ch]

    def input_channel_to_output_channel(self, in_ch):
        """Not every input channel necessarily has an output; can return None"""
        return utils.filt_first_arg(self._input_channels(),
                                    lambda x: x == in_ch)
Example #21
class SimV2Params(ParamsAndPriors):
    # The following constants are repeated in sim_v2.h because it
    # is hard to get constants like this to be shared between
    # the two languages. This shouldn't be a problem as they are stable.
    # TODO: Move these to an import from the pyx
    CycleKindType = np.uint8
    CYCLE_TYPE_PRE = 0
    CYCLE_TYPE_MOCK = 1
    CYCLE_TYPE_EDMAN = 2

    channel__priors__columns = (
        "ch_i",
        "channel_name",
        "bg_mu",
        "bg_sigma",
        "dye_name",
        "gain_mu",
        "gain_sigma",
        "index",
        "p_bleach",
        "row_k_sigma",
    )

    dye__label__priors__columns = (
        "channel_name",
        "dye_name",
        "aa",
        "label_name",
        "ptm_only",
        "p_non_fluorescent",
        "ch_i",
        "bg_mu",
        "bg_sigma",
        "gain_mu",
        "gain_sigma",
        "row_k_sigma",
        "p_bleach",
    )

    defaults = Munch(
        n_pres=1,
        n_mocks=0,
        n_edmans=1,
        n_samples_train=5_000,
        n_samples_test=1_000,
        dyes=[],
        labels=[],
        random_seed=None,
        allow_train_test_to_be_identical=False,
        allow_edman_cterm=False,
        enable_ptm_labels=False,
        is_survey=False,
        train_includes_radmat=False,
        test_includes_dyemat=False,
        dump_debug=False,
        generate_flus=True,
        use_lognormal_model=False,
    )

    schema = s(
        s.is_kws_r(
            priors_desc=Priors.priors_desc_schema,
            is_survey=s.is_bool(),
            n_pres=s.is_int(bounds=(0, None)),
            n_mocks=s.is_int(bounds=(0, None)),
            n_edmans=s.is_int(bounds=(0, None)),
            n_samples_train=s.is_int(bounds=(1, None)),
            n_samples_test=s.is_int(bounds=(1, None)),
            dyes=s.is_list(elems=s.is_kws_r(
                dye_name=s.is_str(),
                channel_name=s.is_str(),
            )),
            labels=s.is_list(elems=s.is_kws_r(
                aa=s.is_str(),
                dye_name=s.is_str(),
                label_name=s.is_str(),
                ptm_only=s.is_bool(required=False, noneable=True),
            )),
            channels=s.is_dict(required=False),
            random_seed=s.is_int(required=False, noneable=True),
            allow_train_test_to_be_identical=s.is_bool(required=False,
                                                       noneable=True),
            allow_edman_cterm=s.is_bool(required=False, noneable=True),
            enable_ptm_labels=s.is_bool(required=False, noneable=True),
            train_includes_radmat=s.is_bool(required=False, noneable=True),
            test_includes_dyemat=s.is_bool(required=False, noneable=True),
            dump_debug=s.is_bool(),
            generate_flus=s.is_bool(),
            use_lognormal_model=s.is_bool(),
        ))

    # def copy(self):
    #     dst = utils.munch_deep_copy(self, klass_set={SimV2Params})
    #     assert isinstance(dst, SimV2Params)
    #     return dst

    def __init__(self, **kwargs):
        # _skip_setup_dfs is True in fixture mode
        super().__init__(source="SimV2Params", **kwargs)
        self._setup_dfs()

    def validate(self):
        super().validate()

        all_dye_names = list(set([d.dye_name for d in self.dyes]))

        # No duplicate dye names
        self._validate(
            len(all_dye_names) == len(self.dyes),
            "The dye list contains a duplicate")

        # No duplicate labels
        self._validate(
            len(list(set(utils.listi(self.labels, "aa")))) == len(self.labels),
            "There is a duplicate label in the label_set",
        )

        # All labels have a legit dye name
        [
            self._validate(
                label.dye_name in all_dye_names,
                f"Label {label.label_name} does not have a valid matching dye_name",
            ) for label in self.labels
        ]

        # Channel mappings
        mentioned_channels = {dye.channel_name: False for dye in self.dyes}
        if "channels" in self:
            # Validate that channel mapping is complete
            for channel_name, ch_i in self.channels.items():
                self._validate(
                    channel_name in mentioned_channels,
                    f"Channel name '{channel_name}' was not found in dyes",
                )
                mentioned_channels[channel_name] = True

            self._validate(
                all([mentioned
                     for _, mentioned in mentioned_channels.items()]),
                "Not all channels in dyes were enumerated in channels",
            )
        else:
            # No channel mapping: assign them
            self["channels"] = {
                ch_name: i
                for i, ch_name in enumerate(sorted(mentioned_channels.keys()))
            }

    @property
    def n_cycles(self):
        return self.n_pres + self.n_mocks + self.n_edmans

    def channel_names(self):
        return [
            ch_name for ch_name, _ in sorted(self.channels.items(),
                                             key=lambda item: item[1])
        ]

    def ch_i_by_name(self):
        return self.channels

    @property
    def n_channels(self):
        # if self.is_photobleaching_run:
        #     return 1
        return len(self.channels)

    @property
    def n_channels_and_cycles(self):
        return self.n_channels, self.n_cycles

    def _setup_dfs(self):
        """
        Assemble all of the priors into several dataframes indexed differently.
        (Call after validate)

        * self.channel__priors:
            ch_i,
            ch_name,
            bg_mu,
            bg_sigma,
            gain_mu,
            gain_sigma,
            row_k_sigma,
            p_bleach
            --> Note, does NOT have p_non_fluorescent because this is a dye property

        * self.dye__label__priors:
            aa,
            label_name,
            dye_name,
            ch_i,
            ch_name,
            bg_mu,
            bg_sigma,
            gain_mu,
            gain_sigma,
            row_k_sigma,
            p_bleach,
            p_non_fluorescent,
        """

        # if self.is_photobleaching_run:
        #     # Not sure what these should be yet
        #     # self._ch_by_aa = {}
        #     # self._channel__priors = pd.DataFrame(columns=self.channel__priors__columns)
        #     # self._dye__label__priors = pd.DataFrame(columns=self.dye__label__priors__columns)
        #     self.dyes = [Munch(dye_name="zero", channel_name="zero")]
        #     self.channels = Munch(zero=0)
        #     self.labels = [
        #         dict(aa=".", dye_name="zero", label_name="zero", ptm_only=False)
        #     ]

        labels_df = pd.DataFrame(self.labels)
        # labels_df: (aa, dye_name, label_name, ptm_only)
        # assert len(labels_df) > 0

        dyes_df = pd.DataFrame(self.dyes)
        # dyes_df: (dye_name, channel_name)
        # assert len(dyes_df) > 0

        # LOOKUP dye priors
        dye_priors = []
        for dye in self.dyes:
            # SEARCH priors by dye name and if not found by channel
            p_non_fluorescent = self.priors.get_exact(
                f"p_non_fluorescent.{dye.dye_name}")
            if p_non_fluorescent is None:
                p_non_fluorescent = self.priors.get(
                    f"p_non_fluorescent.ch_{dye.channel_name}")

            dye_priors += [
                Munch(
                    dye_name=dye.dye_name,
                    p_non_fluorescent=p_non_fluorescent.prior,
                )
            ]

        dye_priors_df = pd.DataFrame(dye_priors)
        # dye_priors_df: (dye_name, p_non_fluorescent)

        dyes_df = utils.easy_join(dyes_df, dye_priors_df, "dye_name")
        # dyes_df: (dye_name, channel_name, p_non_fluorescent)

        # TODO: LOOKUP label priors
        #       (p_failure_to_bind_aa, p_failure_to_attach_to_dye)

        # LOOKUP channel priors
        ch_priors = pd.DataFrame([
            dict(
                channel_name=channel_name,
                ch_i=ch_i,
                bg_mu=self.priors.get(f"bg_mu.ch_{ch_i}").prior,
                bg_sigma=self.priors.get(f"bg_sigma.ch_{ch_i}").prior,
                gain_mu=self.priors.get(f"gain_mu.ch_{ch_i}").prior,
                gain_sigma=self.priors.get(f"gain_sigma.ch_{ch_i}").prior,
                row_k_sigma=self.priors.get(f"row_k_sigma.ch_{ch_i}").prior,
                p_bleach=self.priors.get(f"p_bleach.ch_{ch_i}").prior,
            ) for channel_name, ch_i in self.channels.items()
        ])
        # ch_priors: (channel_name, ch_i, ...)

        self._channel__priors = (utils.easy_join(
            dyes_df, ch_priors, "channel_name").drop(
                columns=["p_non_fluorescent"]).drop_duplicates().reset_index())
        # self._channel__priors: (
        #    'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
        #    'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        # )

        # SANITY check channel__priors
        group_by_ch = self._channel__priors.groupby("ch_i")
        for field in (
                "bg_mu",
                "bg_sigma",
                "gain_mu",
                "gain_sigma",
                "row_k_sigma",
        ):
            assert np.all(group_by_ch[field].nunique() == 1)
        assert "p_non_fluorescent" not in self._channel__priors.columns

        labels_dyes_df = utils.easy_join(labels_df, dyes_df, "dye_name")
        self._dye__label__priors = utils.easy_join(
            labels_dyes_df, ch_priors, "channel_name").reset_index(drop=True)

        # self._dye__label__priors: (
        #     'channel_name', 'dye_name', 'aa', 'label_name',
        #     'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
        #     'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        # )

        self._ch_by_aa = {
            row.aa: row.ch_i
            for row in self._dye__label__priors.itertuples()
        }

    def ch_by_aa(self):
        return self._ch_by_aa

    def dye__label__priors(self):
        """
        DataFrame(
            'channel_name', 'dye_name', 'aa', 'label_name',
            'ptm_only', 'p_non_fluorescent', 'ch_i', 'bg_mu', 'bg_sigma',
            'gain_mu', 'gain_sigma', 'row_k_sigma', 'p_bleach'
        )
        """
        return self._dye__label__priors

    def channel__priors(self):
        """
        DataFrame(
            'ch_i', 'channel_name', 'bg_mu', 'bg_sigma', 'dye_name',
            'gain_mu', 'gain_sigma', 'index', 'p_bleach', 'row_k_sigma',
        )
        """
        return self._channel__priors

    def by_channel(self):
        return self._channel__priors.set_index("ch_i")

    def to_label_list(self):
        """Summarize labels like: ["DE", "C"]"""
        return [
            "".join([
                label.aa for label in self.labels
                if label.dye_name == dye.dye_name
            ]) for dye in self.dyes
        ]

    def to_label_str(self):
        """Summarize labels like: DE,C"""
        return ",".join(self.to_label_list())

    def cycles_array(self):
        cycles = np.zeros((self.n_cycles, ), dtype=self.CycleKindType)
        i = 0
        for _ in range(self.n_pres):
            cycles[i] = self.CYCLE_TYPE_PRE
            i += 1
        for _ in range(self.n_mocks):
            cycles[i] = self.CYCLE_TYPE_MOCK
            i += 1
        for _ in range(self.n_edmans):
            cycles[i] = self.CYCLE_TYPE_EDMAN
            i += 1
        return cycles

    def pcbs(self, pep_seq_df):
        """
        pcb stands for (p)ep_i, (c)hannel_i, (b)right_probability

        This is a structure that is like a "flu" but with an extra bright probability.

        Each peptide has a row for each amino acid
            That row has a columns (pep_i, ch_i, p_bright)
            And it will have np.nan for ch_i and p_bright **IF THERE IS NO LABEL**

        bright_probability is one minus the combined probability of all the ways a dye
        can fail to be visible, i.e. the probability that a dye is active.

        pep_seq_df: Any DataFrame with an "aa" column

        Returns:
            contiguous ndarray(:, 3) where the 3 columns are:
                pep_i, ch_i, p_bright
        """
        labelled_pep_df = pep_seq_df.join(
            self.dye__label__priors().set_index("aa"), on="aa", how="left")

        # p_bright is the product of (1.0 - p) over all the ways the dye can fail to be visible.
        labelled_pep_df["p_bright"] = (
            # TODO: Sim needs to be converted to use priors sampling
            #       at which point this function needs to be refactored
            #       so that the parameters of the priors can be sampled in C.
            1.0 - np.array([
                i.sample() if isinstance(i, Prior) else np.nan
                for i in labelled_pep_df.p_non_fluorescent
            ])
            # TODO: Add label priors
            # * (1.0 - labelled_pep_df.p_failure_to_attach_to_dye)
            # * (1.0 - labelled_pep_df.p_failure_to_bind_aa)
        )

        labelled_pep_df.sort_values(by=["pep_i", "pep_offset_in_pro"],
                                    inplace=True)
        return np.ascontiguousarray(
            labelled_pep_df[["pep_i", "ch_i", "p_bright"]].values)

    @classmethod
    def from_aa_list_fixture(cls, aa_list, priors=None, **kwargs):
        """
        This is a helper to generate channels when you have a list of aas.
        For example, two channels where ch0 is D&E and ch1 is Y:
        ["DE", "Y"].
        """

        check.list_or_tuple_t(aa_list, str)

        allowed_aa_mods = ["[", "]"]
        assert all([(aa.isalpha() or aa in allowed_aa_mods) for aas in aa_list
                    for aa in list(aas)])

        dyes = [
            Munch(dye_name=f"dye_{ch}", channel_name=f"ch_{ch}")
            for ch, _ in enumerate(aa_list)
        ]

        # Note the extra for loop because "DE" needs to be split into "D" & "E"
        # which is done by aa_str_to_list() - which also handles PTMs like S[p]
        labels = [
            Munch(
                aa=aa,
                dye_name=f"dye_{ch}",
                label_name=f"label_{ch}",
                ptm_only=False,
            ) for ch, aas in enumerate(aa_list) for aa in aa_str_to_list(aas)
        ]

        return cls(dyes=dyes, labels=labels, priors=priors, **kwargs)
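
cycles_array() above lays out the cycle kinds as a flat uint8 vector: n_pres entries of CYCLE_TYPE_PRE (0), then n_mocks of CYCLE_TYPE_MOCK (1), then n_edmans of CYCLE_TYPE_EDMAN (2). A minimal standalone sketch of the same layout with example counts (not part of the class):

import numpy as np

n_pres, n_mocks, n_edmans = 1, 2, 3
cycles = np.concatenate([
    np.full(n_pres, 0, dtype=np.uint8),    # CYCLE_TYPE_PRE
    np.full(n_mocks, 1, dtype=np.uint8),   # CYCLE_TYPE_MOCK
    np.full(n_edmans, 2, dtype=np.uint8),  # CYCLE_TYPE_EDMAN
])
assert cycles.tolist() == [0, 1, 1, 2, 2, 2]  # length == n_pres + n_mocks + n_edmans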