def init_by_rep(df, alpha_val, N, param_file):
    """Load mLSTM1900 parameters and build the top model.

    Args:
        df: Passed unchanged as the first argument of ``get_top_model``.
        alpha_val: Single alpha value; wrapped in a one-element list before
            being handed to ``get_top_model``.
        N: Passed unchanged as the third argument of ``get_top_model``
            (earlier comments describe it as the number of training mutants).
        param_file: Forwarded to ``load_params_1900``. ``None`` requests that
            loader's default parameters (earlier comments call this the plain
            UniRep representation); any other value is expected to point at
            evotuned parameters.

    Returns:
        A ``(params, DE_model)`` tuple.
    """
    # Both cases forward ``param_file`` verbatim (``None`` included), so no
    # branch on ``None`` is required.
    params = load_params_1900(param_file)

    # The top model is built from the data and hyper-parameters only; it does
    # not depend on which parameter set was loaded above.
    DE_model = get_top_model(df, [alpha_val], N)

    return params, DE_model
def test_load_params_1900():
    """
    Check that the default parameters pass mLSTM1900 validation.

    Whatever ``load_params_1900`` returns with no arguments is handed straight
    to ``validate_mLSTM1900_params``, which is expected to complain if the
    parameters do not have the shapes the mLSTM1900 needs.
    """
    validate_mLSTM1900_params(load_params_1900())
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """Initialise the embedder and load the UniRep parameters.

    Args:
        device: Must be falsy (e.g. ``None``); UniRep offers no device
            selection, so any truthy value is rejected.
        **kwargs: Forwarded to the parent initialiser.

    Raises:
        NotImplementedError: If a truthy ``device`` is supplied.
    """
    # Imported inside the method so jax_unirep is only needed when this
    # initialiser actually runs.
    from jax_unirep.utils import load_params_1900

    device_requested = bool(device)
    if device_requested:
        raise NotImplementedError("UniRep does not allow configuring the device")

    super().__init__(device, **kwargs)

    # Parameters are loaded only after the parent initialiser has succeeded.
    self.params = load_params_1900()
def test_rep_arbitrary_lengths(seqs, expected):
    """
    Check ``rep_arbitrary_lengths`` against an expected outcome.

    ``expected`` is a context manager: either one that tolerates no exception
    or one that requires a specific exception to be raised. When the call
    completes normally, the three returned arrays must each have shape
    ``(len(seqs), 1900)``.
    """
    params = load_params_1900()

    # Stays None only if the call raised and ``expected`` absorbed the error.
    outputs = None
    with expected:
        outputs = rep_arbitrary_lengths(seqs, params)
        assert outputs is not None

    if outputs is None:
        # The expected exception occurred; there is nothing further to check.
        return

    # Shape checks are keyed on "the call returned" rather than on comparing
    # ``expected`` with a freshly built ``does_not_raise()`` object, which only
    # works if that object defines value equality. This also avoids calling
    # the function a second time.
    h_final, c_final, h_avg = outputs
    expected_shape = (len(seqs), 1900)
    assert h_final.shape == expected_shape
    assert c_final.shape == expected_shape
    assert h_avg.shape == expected_shape
def test_mLSTM1900_batch():
    """
    Run ``mLSTM1900_batch`` on one embedded dummy sequence.

    Only the third returned value is inspected: it must have one 1900-wide
    row for every row of the embedded input.
    """
    embedded = get_embedding("TEST", load_embedding_1900())
    weights = load_params_1900()

    _, _, hidden = mLSTM1900_batch(weights, embedded)

    n_rows = embedded.shape[0]
    assert hidden.shape == (n_rows, 1900)
def test_mLSTM1900_step():
    """
    Feed one step of correctly shaped fake data through ``mLSTM1900_step``
    and confirm the returned hidden and cell states keep shape ``(1, 1900)``.
    """
    state_shape = (1, 1900)

    weights = load_params_1900()
    step_input = npr.normal(size=(1, 10))
    hidden = np.zeros(shape=state_shape)
    cell = np.zeros(shape=state_shape)

    # The step returns ``(new_carry, extra)``; only the new carry is checked.
    new_carry, _ = mLSTM1900_step(weights, (hidden, cell), step_input)
    hidden, cell = new_carry

    assert hidden.shape == state_shape
    assert cell.shape == state_shape
def test_mLSTM1900(data):
    """
    Exercise the ``mLSTM1900`` layer's init/apply pair on a random sequence.

    A sequence of 1-10 characters is drawn from the amino-acid alphabet and
    embedded; the layer is initialised for that length and applied to the
    embedding. The test checks the reported output shape, validates the
    freshly initialised parameters, and checks the shape of the per-step
    outputs.
    """
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )

    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    # The parameters under test come from ``init_fun``; no pre-trained
    # parameters are loaded because they would be overwritten here unused.
    init_fun, apply_fun = mLSTM1900(output_dim=1900)
    output_shape, params = init_fun(rng, (length, 10))
    _, _, outputs = apply_fun(params=params, inputs=x)

    assert output_shape == (length, 1900)
    validate_mLSTM1900_params(params)
    # NOTE(review): the extra row presumably comes from a start token added
    # by ``get_embedding`` -- confirm against that helper.
    assert outputs.shape == (length + 1, 1900)
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """Initialise the embedder with UniRep parameters and its apply function.

    Args:
        device: Must be falsy (e.g. ``None``); UniRep offers no device
            selection, so any truthy value is rejected.
        **kwargs: Forwarded to the parent initialiser.

    Raises:
        NotImplementedError: If a truthy ``device`` is supplied.
    """
    # Reject unsupported arguments first so an invalid call fails before the
    # jax_unirep imports and the parameter load below are performed.
    if device:
        raise NotImplementedError("UniRep does not allow configuring the device")

    from jax_unirep.utils import load_params_1900
    from jax_unirep.featurize import apply_fun

    self._params = load_params_1900()
    self._apply_fun = apply_fun
    # For v2
    # https://github.com/ElArkk/jax-unirep/issues/107
    # from jax_unirep.utils import load_params
    # from jax_unirep.layers import mLSTM
    # from jax_unirep.utils import validate_mLSTM_params
    # self._params = load_params()[1]
    # _, self._apply_fun = mLSTM(output_dim=self.embedding_dimension)
    # validate_mLSTM_params(self._params, n_outputs=self.embedding_dimension)

    # Attributes are set before the parent initialiser runs, as before.
    super().__init__(device, **kwargs)
def test_mLSTM1900_AvgHidden(data):
    """
    Exercise ``mLSTM1900`` chained with ``mLSTM1900_AvgHidden``.

    A sequence of 1-10 characters is drawn from the amino-acid alphabet and
    embedded; the two layers are composed with ``stax.serial``, initialised
    for that length and applied to the embedding. The test checks the
    reported output shape, validates the first layer's parameters, checks
    that the averaging layer has no parameters, and checks the output shape.
    """
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )

    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    # The parameters under test come from ``init_fun``; no pre-trained
    # parameters are loaded because they would be overwritten here unused.
    init_fun, apply_fun = stax.serial(
        mLSTM1900(output_dim=1900),
        mLSTM1900_AvgHidden(output_dim=1900),
    )
    output_shape, params = init_fun(rng, (length, 10))
    h_avg = apply_fun(params=params, inputs=x)

    assert output_shape == (1900,)
    # ``params`` holds one entry per composed layer, in order.
    validate_mLSTM1900_params(params[0])
    assert params[1] == ()
    assert h_avg.shape == (1900,)