def sample_uniformly(self, num_samples, min_seq_len=None, max_seq_len=None,
                     pad=True, seed=None):
    """Draws random integer-encoded sequences of random length.

    Every sequence gets a length drawn from `[min_seq_len, max_seq_len]` and
    is then filled with tokens drawn (with replacement) from the vocabulary
    with the end-of-sequence entry removed.

    Args:
      num_samples: How many sequences to draw.
      min_seq_len: Smallest allowed sequence length (inclusive). Falls back
        to `self.min_length` when None.
      max_seq_len: Largest allowed sequence length (inclusive). Falls back
        to `self.length` when None.
      pad: If True, the drawn sequences are passed through
        `seq_utils.pad_sequences` with target length `self.length` and
        `self.vocab.eos` as the fill value.
      seed: Optional seed forwarded to `utils.get_random_state`.

    Returns:
      When `pad` is False, a list of `num_samples` arrays, one per sequence.
      When `pad` is True, whatever `seq_utils.pad_sequences` returns for
      that list.
    """
    shortest = self.min_length if min_seq_len is None else min_seq_len
    longest = self.length if max_seq_len is None else max_seq_len

    rng = utils.get_random_state(seed)

    # np.delete drops the entry at *index* `self.vocab.eos`.
    # NOTE(review): this presumes token ids coincide with their positions in
    # `self.vocab.token_ids` -- confirm against the vocabulary class.
    candidate_ids = np.delete(self.vocab.token_ids, self.vocab.eos)

    # `longest + 1` so that the upper bound is inclusive. All lengths are
    # drawn up front, before any token draws, which fixes the order in which
    # the random stream is consumed for a given seed.
    seq_lengths = rng.randint(shortest, longest + 1, num_samples)

    samples = []
    for seq_length in seq_lengths:
        samples.append(rng.choice(candidate_ids, seq_length))

    if not pad:
        return samples
    return seq_utils.pad_sequences(samples, self.length, self.vocab.eos)
def sample_uniformly(self, num_samples, seed=None):
    """Draws a batch of fixed-length random integer sequences.

    Args:
      num_samples: How many sequences (rows) to draw.
      seed: Optional seed forwarded to `utils.get_random_state`.

    Returns:
      The result of a single `randint` draw of shape
      `[num_samples, self.length]` with bounds `low=0` and
      `high=self.vocab_size`, converted with `np.int32`.
    """
    rng = utils.get_random_state(seed)
    shape = [num_samples, self.length]
    # One draw for the whole batch; the conversion afterwards pins the
    # element type to int32 regardless of the generator's default.
    raw_draws = rng.randint(low=0, high=self.vocab_size, size=shape)
    return np.int32(raw_draws)