def lstm_sample_fun(rng_key, lstm, target_lens, score_fun, ref_sets=None,
                    samp_size=None, samp_mode='fixed'):
    """Draw samples from LSTM model within Pytorch

    Legacy variant that takes the sample lengths explicitly.
    NOTE(review): this module later defines another function named
    ``lstm_sample_fun``, which replaces this one at import time. Confirm
    whether this version is still needed.

    Args:
        rng_key: JAX PRNG key, split once per target length.
        lstm: model exposing ``bi_dir``, ``eval()`` and ``sample(...)``.
        target_lens: iterable of sample lengths to evaluate.
        score_fun: callable scoring a list of sampled strings. Its result is
            multiplied by 100 and reported as a percentage.
        ref_sets: unused here. The bidirectional path is not implemented.
        samp_size: number of samples drawn per length. When None, falls
            back to ``EXP_ARGS['samp_size']`` (previously this name was
            undefined and the unidirectional path always raised NameError).
        samp_mode: sampling mode forwarded to ``lstm.sample``. The default
            is the previously hard-coded ``'fixed'``.

    Returns:
        dict mapping each length in ``target_lens`` to 100 * score.

    Raises:
        NotImplementedError: if ``lstm.bi_dir`` is true (bidirectional
            sampling is unfinished in this version).
    """
    bi_exp = lstm.bi_dir
    # Sampling alphabet includes the BOS/EOS symbols
    this_alph = alph + ALPHABET['bos_eos']
    lstm = lstm.eval()

    if bi_exp:
        # TODO: Finish bidirectional sampling, including (a) dealing with
        # BOS/EOS and (b) putting the sampled characters back into the
        # ref_set strings.
        raise NotImplementedError(
            "Bidirectional sampling is not implemented in this version "
            "of lstm_sample_fun")

    if samp_size is None:
        samp_size = EXP_ARGS['samp_size']

    corr_frac = {}
    for samp_l in target_lens:
        rng_key, key = jax.random.split(rng_key)
        samples = lstm.sample(key, this_alph, samp_mode=samp_mode,
                              num_samps=samp_size, samp_len=samp_l)
        score = score_fun(samples)
        corr_frac[samp_l] = 100 * score
        examples = samples[:10]
        print(f"Correct frac len={samp_l}: {100 * score:.1f}%")
        print(f"Example samples: {examples}\n")
    return corr_frac
def mps_sample_fun(rng_key, mps, score_fun, alph, ref_sets=None):
    """Draw samples from MPS model within JAX

    Sampling settings come from the module-level ``EXP_ARGS``:
    ``samp_lens`` (lengths to evaluate), ``samp_size`` (samples per length
    in the unidirectional case) and ``bi_exp`` (whether to run the
    fill-in-the-blank experiment on reference strings).

    Args:
        rng_key: JAX PRNG key, split once per sample length.
        mps: MPS model passed to ``fill_in_blanks`` / ``draw_samples``.
        score_fun: callable scoring a list of strings. Its result is
            multiplied by 100 and reported as a percentage.
        alph: alphabet passed to the samplers and to ``to_string``.
        ref_sets: mapping from sample length to a reference set. Required
            when ``EXP_ARGS['bi_exp']`` is true.

    Returns:
        (corr_frac, examp_samps): dicts keyed by sample length. They hold
        100 * score and the first 10 generated strings, respectively.

    Raises:
        ValueError: if ``EXP_ARGS['bi_exp']`` is true but ``ref_sets`` is None.
    """
    from sampler import draw_samples

    samp_lens = EXP_ARGS['samp_lens']
    samp_size = EXP_ARGS['samp_size']
    bi_exp = EXP_ARGS['bi_exp']
    corr_frac = {}
    examp_samps = {}

    if bi_exp:
        if ref_sets is None:
            raise ValueError(
                "ref_sets must be provided when EXP_ARGS['bi_exp'] is set")
        for samp_l in samp_lens:
            ref_s = ref_sets[samp_l]
            # String form of the references. It gets its own name so that
            # `ref_sets` (the length -> reference-set mapping) stays intact
            # for the remaining lengths.
            ref_strs = to_string(ref_s, alph)
            rng_key, key = jax.random.split(rng_key)
            samp_chars = fill_in_blanks(key, mps, alphabet=alph,
                                        ref_strset=ref_s)

            # TODO: Fold this code into fill_in_blanks
            # Generate validation strings with each character replaced by
            # the suggested character from samp_chars: one candidate per
            # (reference string, position) pair.
            samples = [
                s[:i] + c + s[i + 1:]
                for s, cs in zip(ref_strs, samp_chars)
                for i, c in enumerate(cs)
            ]
            corr_frac[samp_l] = 100 * score_fun(samples)
            examp_samps[samp_l] = samples[:10]
            m_print(f"Correct frac len={samp_l}: {corr_frac[samp_l]:.1f}%")
            m_print(f"Replacement examples: {samples[:10]}\n")
    else:
        for samp_l in samp_lens:
            rng_key, key = jax.random.split(rng_key)
            samples = draw_samples(key, mps, alphabet=alph,
                                   num_samps=samp_size, samp_len=samp_l)
            score = score_fun(samples)
            corr_frac[samp_l] = 100 * score
            examp_samps[samp_l] = samples[:10]
            m_print(f"Correct frac len={samp_l}: {100 * score:.1f}%")
            m_print(f"Example samples: {samples[:10]}\n")
    return corr_frac, examp_samps
def lstm_sample_fun(rng_key, lstm, score_fun, alph, ref_sets=None):
    """Draw samples from LSTM model within Pytorch

    For every length in ``EXP_ARGS['samp_lens']`` the model is sampled,
    the batch is scored with ``score_fun``, and the score is reported.

    - Bidirectional models (``lstm.bi_dir``) run in ``'completion'`` mode
      against ``ref_sets[length]``.
    - Other models draw ``EXP_ARGS['samp_size']`` samples in mode
      ``EXP_ARGS['samp_mode']``.

    Args:
        rng_key: JAX PRNG key, split once per length.
        lstm: model exposing ``bi_dir``, ``eval()`` and ``sample(...)``.
        score_fun: callable scoring a list of strings. Its result is
            multiplied by 100 and reported as a percentage.
        alph: alphabet forwarded to ``lstm.sample``.
        ref_sets: mapping from length to reference set. Only read for
            bidirectional models.

    Returns:
        (corr_frac, examp_samps): dicts keyed by length, holding 100 * score
        and the first 10 samples, respectively.
    """
    lengths = EXP_ARGS['samp_lens']
    num_samples = EXP_ARGS['samp_size']
    mode = EXP_ARGS['samp_mode']
    bidirectional = lstm.bi_dir
    lstm = lstm.eval()

    pct_by_len = {}
    examples_by_len = {}
    for length in lengths:
        if bidirectional:
            references = ref_sets[length]
            rng_key, subkey = jax.random.split(rng_key)
            raw = lstm.sample(subkey, alph, samp_mode='completion',
                              ref_strset=references)
            # BOS ('^') and EOS ('$') should never be sampled. Any string
            # containing one is swapped for the placeholder ')(' (intended
            # as an incorrect string).
            batch = []
            for text in raw:
                if '^' in text or '$' in text:
                    batch.append(')(')
                else:
                    batch.append(text)
        else:
            rng_key, subkey = jax.random.split(rng_key)
            batch = lstm.sample(subkey, alph, samp_mode=mode,
                                num_samps=num_samples, samp_len=length)

        pct = 100 * score_fun(batch)
        pct_by_len[length] = pct
        examples_by_len[length] = batch[:10]

        m_print(f"Correct frac len={length}: {pct:.1f}%")
        if bidirectional:
            m_print(f"Replacement examples:{examples_by_len[length]}\n")
        else:
            m_print(f"Example samples: {examples_by_len[length]}\n")

    return pct_by_len, examples_by_len