def __init__(self,
             alphabet=DNA,
             neutral_alphabet='N',
             neutral_value=0.25,
             dtype=None,
             alphabet_axis=1,
             dummy_axis=None):
    """Configure a one-hot encoding pipeline with a custom axis layout.

    The resulting ``self.transform`` chains three steps: ``OneHot``,
    ``DummyAxis`` and ``SwapAxes``.

    Args:
        alphabet: alphabet specification, normalised via ``parse_alphabet``.
        neutral_alphabet: forwarded unchanged to ``OneHot``.
        neutral_value: forwarded unchanged to ``OneHot``.
        dtype: dtype specification, normalised via ``parse_dtype``.
        alphabet_axis: requested position of the alphabet axis in the
            output. Must be 0 or 1; 2 is additionally allowed when
            ``dummy_axis`` is given.
        dummy_axis: position (0, 1 or 2) at which an extra axis is
            injected, or ``None`` to inject none.

    Raises:
        ValueError: if ``dummy_axis`` equals ``alphabet_axis`` or lies
            outside ``[0, 2]``.
        AssertionError: if ``alphabet_axis`` is outside the allowed range.

    Note:
        ``self.alphabet_axis`` is reset to ``None`` when the alphabet axis
        already sits at the requested position (presumably ``SwapAxes``
        treats ``None`` as "no swap" -- confirm against its definition).
    """
    # Make sure the alphabet axis and the dummy axis are valid.
    if dummy_axis is not None:
        if alphabet_axis == dummy_axis:
            raise ValueError("alphabet_axis can't be the same as dummy_axis")
        if not 0 <= dummy_axis <= 2:
            raise ValueError("dummy_axis can be either 0,1 or 2")

    # Explicit raise instead of a bare ``assert`` so the check is not
    # stripped under ``python -O``; the exception type stays AssertionError
    # for backward compatibility.
    alphabet_axis_ok = alphabet_axis >= 0 and (
        alphabet_axis < 2
        or (alphabet_axis <= 2 and dummy_axis is not None)
    )
    if not alphabet_axis_ok:
        raise AssertionError(
            "alphabet_axis must be 0 or 1 (or 2 when dummy_axis is set), "
            "got {!r}".format(alphabet_axis)
        )

    self.alphabet_axis = alphabet_axis
    self.dummy_axis = dummy_axis
    self.alphabet = parse_alphabet(alphabet)
    self.dtype = parse_dtype(dtype)
    self.neutral_alphabet = neutral_alphabet
    self.neutral_value = neutral_value

    # Work out where the alphabet axis sits after the optional dummy axis
    # has been injected.
    if dummy_axis is not None and dummy_axis < 2:
        # Dummy axis is added at position 0 or 1, which pushes the alphabet
        # axis to the end.
        existing_alphabet_axis = 2
    else:
        # No dummy axis, or it is appended last: alphabet axis stays at 1.
        existing_alphabet_axis = 1

    # No swapping needed when the alphabet axis is already in place.
    if existing_alphabet_axis == self.alphabet_axis:
        self.alphabet_axis = None

    # How to transform the input.
    self.transform = Compose([
        # one-hot-encode
        OneHot(self.alphabet,
               neutral_alphabet=self.neutral_alphabet,
               neutral_value=self.neutral_value,
               dtype=self.dtype),
        # optionally inject the dummy axis
        DummyAxis(self.dummy_axis),
        # put the alphabet axis elsewhere
        SwapAxes(existing_alphabet_axis, self.alphabet_axis),
    ])
def test_parse_alphabet():
    """parse_alphabet yields ['A', 'C'] for both the list and the string form."""
    # A list of letters should compare equal to the same list of letters.
    parsed_from_list = parse_alphabet(['A', 'C'])
    assert parsed_from_list == ['A', 'C']

    # A plain string should be split into its individual letters.
    parsed_from_string = parse_alphabet('AC')
    assert parsed_from_string == ['A', 'C']