예제 #1
0
    def __init__(self,
                 alphabet=DNA,
                 neutral_alphabet='N',
                 neutral_value=0.25,
                 dtype=None,
                 alphabet_axis=1,
                 dummy_axis=None):
        """Configure the one-hot-encode -> dummy-axis -> swap-axes pipeline.

        Args:
            alphabet: alphabet specification; normalised via `parse_alphabet`.
            neutral_alphabet: forwarded unchanged to `OneHot`
                (presumably the symbol(s) treated as "unknown" - confirm
                against `OneHot`).
            neutral_value: forwarded unchanged to `OneHot`
                (presumably the fill value used for neutral symbols -
                confirm against `OneHot`).
            dtype: output dtype specification; normalised via `parse_dtype`.
            alphabet_axis: requested position of the alphabet axis in the
                output. Must be 0 or 1, or 2 when `dummy_axis` is given.
            dummy_axis: optional position (0, 1 or 2) of an extra axis that
                is handed to `DummyAxis`; `None` means no extra axis.

        Raises:
            ValueError: if `dummy_axis` equals `alphabet_axis` or lies
                outside the range 0..2.
            AssertionError: if `alphabet_axis` is outside its allowed range.

        Note:
            `self.alphabet_axis` is reset to `None` when the alphabet axis is
            already in the requested position, i.e. when no swap is needed.
        """
        # make sure the alphabet axis and the dummy axis are valid:
        if dummy_axis is not None:
            if alphabet_axis == dummy_axis:
                raise ValueError(
                    "alphabet_axis can't be the same as dummy_axis")
            if not 0 <= dummy_axis <= 2:
                raise ValueError("dummy_axis can be either 0,1 or 2")
        # NOTE(review): kept as an assert so callers still see AssertionError;
        # be aware that asserts are skipped when Python runs with -O.
        assert alphabet_axis >= 0 and (
            alphabet_axis < 2
            or (alphabet_axis <= 2 and dummy_axis is not None)
        ), "alphabet_axis must be 0 or 1, or 2 when dummy_axis is given"

        self.alphabet_axis = alphabet_axis
        self.dummy_axis = dummy_axis
        self.alphabet = parse_alphabet(alphabet)
        self.dtype = parse_dtype(dtype)
        self.neutral_alphabet = neutral_alphabet
        self.neutral_value = neutral_value

        # Work out where the alphabet axis sits after the (optional) dummy
        # axis has been injected, i.e. right before the final swap.
        if dummy_axis is not None and dummy_axis < 2:
            # dummy axis is inserted in front of the alphabet axis, which
            # therefore moves from position 1 to position 2
            existing_alphabet_axis = 2
        else:
            # no dummy axis, or it is appended at position 2:
            # the alphabet axis stays at position 1
            existing_alphabet_axis = 1

        # No swapping needed when the alphabet axis is already in place;
        # this is signalled to SwapAxes by passing None as the target.
        if existing_alphabet_axis == self.alphabet_axis:
            self.alphabet_axis = None

        # how to transform the input
        self.transform = Compose([
            OneHot(self.alphabet,
                   neutral_alphabet=self.neutral_alphabet,
                   neutral_value=self.neutral_value,
                   dtype=self.dtype),  # one-hot-encode
            DummyAxis(self.dummy_axis),  # optionally inject the dummy axis
            SwapAxes(existing_alphabet_axis,
                     self.alphabet_axis),  # put the alphabet axis elsewhere
        ])
예제 #2
0
def test_parse_alphabet():
    """`parse_alphabet` returns a list of letters for list and string input."""
    expected = ['A', 'C']
    # Same two cases, in the same order: an explicit list, then a string.
    for raw in (['A', 'C'], 'AC'):
        assert parse_alphabet(raw) == expected