Example #1
0
def lstm_sample_fun(rng_key, lstm, target_lens, score_fun, ref_sets=None):
    """Draw samples from LSTM model within Pytorch and score them.

    Args:
        rng_key: JAX PRNG key; it is split once per entry of ``target_lens``.
        lstm: Model object exposing ``bi_dir``, ``eval()`` and ``sample(...)``.
        target_lens: Iterable of sample lengths to draw and score.
        score_fun: Callable applied to the drawn samples; its result is
            multiplied by 100 and reported as a percentage.
        ref_sets: Reference strings intended for bidirectional sampling.
            Currently unused, because that mode is not implemented.

    Returns:
        Dict mapping each length in ``target_lens`` to ``100 * score``.

    Raises:
        NotImplementedError: If ``lstm.bi_dir`` is truthy.

    Note:
        ``alph``, ``ALPHABET`` and ``samp_size`` are not parameters of this
        function; they are read from the enclosing module scope.
    """
    samp_mode = 'fixed'
    bi_exp = lstm.bi_dir
    this_alph = alph + ALPHABET['bos_eos']
    lstm = lstm.eval()

    if bi_exp:
        # TODO: Finish up better bidirectional sampling code, including
        #       (a) deal with BOS/EOS, (b) properly put sampled characters
        #       into the reference strings.
        # The previous draft sat behind this raise and could never run (it
        # also read an undefined `samples` and overwrote `ref_sets`), so the
        # function now fails up front with an explicit message instead.
        raise NotImplementedError(
            "Bidirectional LSTM sampling is not implemented yet")

    corr_frac = {}
    for samp_l in target_lens:
        rng_key, key = jax.random.split(rng_key)
        samples = lstm.sample(key,
                              this_alph,
                              samp_mode=samp_mode,
                              num_samps=samp_size,
                              samp_len=samp_l)
        score = score_fun(samples)
        corr_frac[samp_l] = 100 * score
        # Only a short prefix of the samples is shown for inspection
        examples = samples[:10]
        print(f"Correct frac len={samp_l}:  {100 * score:.1f}%")
        print(f"Example samples: {examples}\n")

    return corr_frac
Example #2
0
def mps_sample_fun(rng_key, mps, score_fun, alph, ref_sets=None):
    """Draw samples from MPS model within JAX and score them.

    Sampling lengths, sample count and the sampling mode are read from the
    module-level ``EXP_ARGS`` mapping (keys ``'samp_lens'``, ``'samp_size'``
    and ``'bi_exp'``).

    Args:
        rng_key: JAX PRNG key; it is split once per sampled length.
        mps: Model passed through to ``fill_in_blanks`` / ``draw_samples``.
        score_fun: Callable applied to the list of samples; its result is
            multiplied by 100 and reported as a percentage.
        alph: Alphabet passed to the samplers and to ``to_string``.
        ref_sets: Collection indexed by sample length that yields the
            reference set for that length. Only read when
            ``EXP_ARGS['bi_exp']`` is truthy, in which case it must be given.

    Returns:
        Tuple ``(corr_frac, examp_samps)``: two dicts keyed by sample length
        holding ``100 * score`` and the first ten samples respectively.
    """
    from sampler import draw_samples
    samp_lens = EXP_ARGS['samp_lens']
    samp_size = EXP_ARGS['samp_size']
    bi_exp = EXP_ARGS['bi_exp']

    corr_frac = {}
    examp_samps = {}
    if bi_exp:
        for samp_l in samp_lens:
            ref_s = ref_sets[samp_l]
            # Keep the decoded strings under their own name: rebinding
            # `ref_sets` here would make the lookup above hit the string
            # list instead of the caller's collection on the next length.
            ref_strs = to_string(ref_s, alph)
            rng_key, key = jax.random.split(rng_key)
            samp_chars = fill_in_blanks(key,
                                        mps,
                                        alphabet=alph,
                                        ref_strset=ref_s)
            # TODO: Fold this code into fill_in_blanks
            # Generate validation strings with each character replaced by
            # suggested character from samp_chars
            samples = [
                s[:i] + c + s[i + 1:] for s, cs in zip(ref_strs, samp_chars)
                for i, c in enumerate(cs)
            ]
            corr_frac[samp_l] = 100 * score_fun(samples)
            examp_samps[samp_l] = samples[:10]
            m_print(f"Correct frac len={samp_l}:  {corr_frac[samp_l]:.1f}%")
            m_print(f"Replacement examples: {samples[:10]}\n")
    else:
        for samp_l in samp_lens:
            rng_key, key = jax.random.split(rng_key)
            samples = draw_samples(key,
                                   mps,
                                   alphabet=alph,
                                   num_samps=samp_size,
                                   samp_len=samp_l)
            score = score_fun(samples)
            corr_frac[samp_l] = 100 * score
            examp_samps[samp_l] = samples[:10]
            m_print(f"Correct frac len={samp_l}:  {100 * score:.1f}%")
            m_print(f"Example samples: {samples[:10]}\n")

    return corr_frac, examp_samps
Example #3
0
def lstm_sample_fun(rng_key, lstm, score_fun, alph, ref_sets=None):
    """Sample from an LSTM model (Pytorch) and score the results per length.

    The lengths, sample count and sampling mode come from the module-level
    ``EXP_ARGS`` mapping. When ``lstm.bi_dir`` is truthy the model is asked
    to complete the reference set for each length (``ref_sets[length]``);
    otherwise fresh samples of the requested length are drawn.

    Returns:
        Tuple ``(corr_frac, examp_samps)`` of dicts keyed by sample length,
        holding ``100 * score_fun(samples)`` and the first ten samples.
    """
    lengths = EXP_ARGS['samp_lens']
    n_draws = EXP_ARGS['samp_size']
    draw_mode = EXP_ARGS['samp_mode']
    bidirectional = lstm.bi_dir
    lstm = lstm.eval()

    corr_frac, examp_samps = {}, {}
    for length in lengths:
        if bidirectional:
            refs = ref_sets[length]
            rng_key, subkey = jax.random.split(rng_key)
            drawn = lstm.sample(subkey,
                                alph,
                                samp_mode='completion',
                                ref_strset=refs)
            # BOS and EOS should never be sampled; any string containing
            # one of them is swapped for a fixed placeholder string
            drawn = [
                s if ('^' not in s and '$' not in s) else ')('
                for s in drawn
            ]
            label = "Replacement examples:"
        else:
            rng_key, subkey = jax.random.split(rng_key)
            drawn = lstm.sample(subkey,
                                alph,
                                samp_mode=draw_mode,
                                num_samps=n_draws,
                                samp_len=length)
            label = "Example samples: "

        percent = 100 * score_fun(drawn)
        head = drawn[:10]
        corr_frac[length] = percent
        examp_samps[length] = head
        m_print(f"Correct frac len={length}:  {percent:.1f}%")
        m_print(f"{label}{head}\n")

    return corr_frac, examp_samps