def sample_uniformly(self,
                     num_samples,
                     min_seq_len=None,
                     max_seq_len=None,
                     pad=True,
                     seed=None):
    """Samples valid integer-encoded sequences from the domain.

    Each sample's length is drawn uniformly from the inclusive range
    [min_seq_len, max_seq_len]. Each position is then drawn uniformly, with
    replacement, from the vocabulary's token ids excluding the EOS token.

    Args:
      num_samples: The number of samples.
      min_seq_len: The minimum sequence length of samples (inclusive).
        Defaults to `self.min_length`.
      max_seq_len: The maximum sequence length of samples (inclusive).
        Defaults to `self.length`.
      pad: Whether to pad sequences to `self.length` using `self.vocab.eos`
        (via `seq_utils.pad_sequences`).
      seed: Optional seed of the random number generator.

    Returns:
      `num_samples` samples. When `pad` is False this is a list of 1-D integer
      arrays; otherwise it is whatever `seq_utils.pad_sequences` returns.

    Raises:
      ValueError: If `min_seq_len` is negative or greater than `max_seq_len`.
    """
    if min_seq_len is None:
        min_seq_len = self.min_length
    if max_seq_len is None:
        max_seq_len = self.length
    # Fail early with a clear message instead of an opaque error from
    # `randint` (low >= high) or `choice` (negative size).
    if min_seq_len < 0:
        raise ValueError(
            'min_seq_len must be non-negative, got %r' % (min_seq_len,))
    if min_seq_len > max_seq_len:
        raise ValueError(
            'min_seq_len (%r) must not exceed max_seq_len (%r)' %
            (min_seq_len, max_seq_len))

    random_state = utils.get_random_state(seed)

    # Exclude EOS by *value*. `np.delete` removes by *index*, which matches
    # removing the EOS id only when token_ids is exactly range(len(vocab)).
    token_ids = np.asarray(self.vocab.token_ids)
    valid_token_ids = token_ids[token_ids != self.vocab.eos]

    # `randint`'s upper bound is exclusive, hence the +1 for inclusivity.
    lengths = random_state.randint(min_seq_len, max_seq_len + 1, num_samples)
    seqs = [random_state.choice(valid_token_ids, length) for length in lengths]
    if pad:
        seqs = seq_utils.pad_sequences(seqs, self.length, self.vocab.eos)
    return seqs
def encode(self, sequences, pad=True):
    """Converts sequences to integer ids, padding them when requested.

    Args:
      sequences: Iterable of sequences; each is encoded with
        `self.vocab.encode`.
      pad: If True, pad the encoded sequences to `self.length` using
        `self.vocab.eos` as the fill value (via `seq_utils.pad_sequences`).

    Returns:
      A list of encoded sequences when `pad` is False; otherwise the result
      of `seq_utils.pad_sequences`.
    """
    vocab = self.vocab
    result = list(map(vocab.encode, sequences))
    if not pad:
        return result
    return seq_utils.pad_sequences(result, self.length, vocab.eos)