Пример #1
0
def init_by_rep(df, alpha_val, N, param_file):
    """
    Load representation weights and build the top model.

    :param df: Data passed through unchanged to ``get_top_model``.
    :param alpha_val: Single alpha value; it is wrapped in a one-element
        list before being handed to ``get_top_model``.
    :param N: Passed through unchanged to ``get_top_model``
        (per the original comments, the number of training mutants).
    :param param_file: Argument forwarded to ``load_params_1900``.
        ``None`` selects the default weights (per the original comments:
        the plain UniRep representation); anything else is presumably
        a path to evotuned weights -- confirm against ``load_params_1900``.
    :returns: Tuple ``(params, DE_model)``.
    """
    # The former None / not-None branches were identical apart from passing
    # a literal ``None`` versus ``param_file`` (which is None in that case),
    # so one code path covers both.
    params = load_params_1900(param_file)
    DE_model = get_top_model(df, [alpha_val], N)

    return params, DE_model
Пример #2
0
def test_load_params_1900():
    """
    Check that the weights returned by ``load_params_1900``
    pass the mLSTM1900 parameter validation.
    """
    validate_mLSTM1900_params(load_params_1900())
Пример #3
0
    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs) -> None:
        """
        Initialise the embedder and load the UniRep 1900 weights.

        :param device: Must be falsy (e.g. ``None``); any truthy value is
            rejected because device selection is not supported here.
        :param kwargs: Forwarded unchanged to the parent initialiser.
        :raises NotImplementedError: If ``device`` is truthy.
        """
        # Imported locally so jax_unirep is only needed once this class is
        # instantiated.
        from jax_unirep.utils import load_params_1900

        # Fail before any weights are loaded when a device is requested.
        if device:
            raise NotImplementedError("UniRep does not allow configuring the device")
        super().__init__(device, **kwargs)

        # Weights are stored on the instance for later use.
        self.params = load_params_1900()
Пример #4
0
def test_rep_arbitrary_lengths(seqs, expected):
    """
    Run ``rep_arbitrary_lengths`` under the ``expected`` context manager
    and, when the call succeeds, check the shapes of its three outputs.

    :param seqs: Sequences handed to ``rep_arbitrary_lengths``.
    :param expected: Context manager describing the expected outcome
        (one that expects an exception, or a no-op one).
    """
    params = load_params_1900()

    with expected:
        reps = rep_arbitrary_lengths(seqs, params)
        assert reps is not None

        # The shape checks live inside the ``with`` block on purpose.
        # The previous guard ``expected == does_not_raise()`` compared two
        # distinct context-manager objects, which (absent a custom __eq__)
        # is never true, so these assertions were silently skipped.
        # When ``expected`` anticipates an exception, the call above raises
        # and the lines below are never reached.
        h_final, c_final, h_avg = reps
        expected_shape = (len(seqs), 1900)
        assert h_final.shape == expected_shape
        assert c_final.shape == expected_shape
        assert h_avg.shape == expected_shape
Пример #5
0
def test_mLSTM1900_batch():
    """
    Embed the string "TEST" and run it through ``mLSTM1900_batch``;
    the third output must have one 1900-wide row per row of the input.
    """
    x = get_embedding("TEST", load_embedding_1900())
    weights = load_params_1900()

    # Only the third returned value is inspected.
    _, _, hidden = mLSTM1900_batch(weights, x)

    n_rows = x.shape[0]
    assert hidden.shape == (n_rows, 1900)
Пример #6
0
def test_mLSTM1900_step():
    """
    Feed one random input of shape (1, 10) and a zero-initialised carry
    into ``mLSTM1900_step`` and check that both returned state arrays
    keep the (1, 1900) shape.
    """
    weights = load_params_1900()

    step_input = npr.normal(size=(1, 10))
    state_shape = (1, 1900)
    hidden0 = np.zeros(shape=state_shape)
    cell0 = np.zeros(shape=state_shape)

    new_carry, _ = mLSTM1900_step(weights, (hidden0, cell0), step_input)
    hidden1, cell1 = new_carry

    assert hidden1.shape == state_shape
    assert cell1.shape == state_shape
Пример #7
0
def test_mLSTM1900(data):
    """
    Draw a random sequence of length 1-10, embed it, and check the
    ``mLSTM1900`` layer: the shape reported by ``init_fun``, the freshly
    initialised parameters, and the shape of the per-step outputs.

    :param data: Object providing ``draw`` for sampling from the ``st``
        strategies (presumably hypothesis' ``st.data()`` -- the decorator
        is not visible here).
    """
    # NOTE: the pretrained weights used to be loaded here via
    # ``load_params_1900()`` but were overwritten by ``init_fun`` below
    # before ever being read, so that load has been dropped.
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )
    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    init_fun, apply_fun = mLSTM1900(output_dim=1900)
    output_shape, params = init_fun(rng, (length, 10))
    # Only the third returned value is checked.
    _, _, outputs = apply_fun(params=params, inputs=x)

    assert output_shape == (length, 1900)
    validate_mLSTM1900_params(params)
    # The outputs are expected to have one more row than the sequence has
    # characters (presumably get_embedding adds a start token -- confirm).
    assert outputs.shape == (length + 1, 1900)
Пример #8
0
    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Initialise the embedder: load the UniRep 1900 weights and keep a
        reference to jax_unirep's ``apply_fun``.

        :param device: Must be falsy (e.g. ``None``); any truthy value is
            rejected because device selection is not supported here.
        :param kwargs: Forwarded unchanged to the parent initialiser.
        :raises NotImplementedError: If ``device`` is truthy.
        """
        from jax_unirep.utils import load_params_1900
        from jax_unirep.featurize import apply_fun

        # Reject an unsupported device request before loading the weights,
        # so the error path does no unnecessary work.
        if device:
            raise NotImplementedError("UniRep does not allow configuring the device")

        self._params = load_params_1900()
        self._apply_fun = apply_fun

        # For v2
        # https://github.com/ElArkk/jax-unirep/issues/107
        # from jax_unirep.utils import load_params
        # from jax_unirep.layers import mLSTM
        # from jax_unirep.utils import validate_mLSTM_params
        # self._params = load_params()[1]
        # _, self._apply_fun = mLSTM(output_dim=self.embedding_dimension)
        # validate_mLSTM_params(self._params, n_outputs=self.embedding_dimension)

        # Attributes are set before the parent initialiser runs, as before.
        super().__init__(device, **kwargs)
Пример #9
0
def test_mLSTM1900_AvgHidden(data):
    """
    Draw a random sequence of length 1-10, embed it, and check the
    ``mLSTM1900`` + ``mLSTM1900_AvgHidden`` stack built with
    ``stax.serial``: reported output shape, parameters of each layer,
    and shape of the resulting averaged hidden state.

    :param data: Object providing ``draw`` for sampling from the ``st``
        strategies (presumably hypothesis' ``st.data()`` -- the decorator
        is not visible here).
    """
    # NOTE: the pretrained weights used to be loaded here via
    # ``load_params_1900()`` but were overwritten by ``init_fun`` below
    # before ever being read, so that load has been dropped.
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )
    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    init_fun, apply_fun = stax.serial(
        mLSTM1900(output_dim=1900),
        mLSTM1900_AvgHidden(output_dim=1900),
    )
    output_shape, params = init_fun(rng, (length, 10))
    h_avg = apply_fun(params=params, inputs=x)

    assert output_shape == (1900,)
    # params holds one entry per layer: the mLSTM weights first, then an
    # empty tuple for the averaging layer.
    validate_mLSTM1900_params(params[0])
    assert params[1] == ()
    assert h_avg.shape == (1900,)