Ejemplo n.º 1
0
    def sample_uniformly(self,
                         num_samples,
                         min_seq_len=None,
                         max_seq_len=None,
                         pad=True,
                         seed=None):
        """Samples valid integer-encoded sequences from the domain.

        Each sequence length is drawn uniformly from the inclusive range
        `[min_seq_len, max_seq_len]`. Each position is then drawn uniformly
        from the vocabulary's token ids, excluding the EOS token id.

        Args:
          num_samples: The number of samples.
          min_seq_len: The minimum sequence length of samples (inclusive).
            Defaults to `self.min_length`.
          max_seq_len: The maximum sequence length of samples (inclusive).
            Defaults to `self.length`.
          pad: Whether to pass the samples through `seq_utils.pad_sequences`
            with target length `self.length` and `self.vocab.eos`
            (presumably used as the pad value; confirm in `seq_utils`).
          seed: Optional seed of the random number generator.

        Returns:
          If `pad` is False, a list with `num_samples` samples, one array of
          token ids per sample. Otherwise, whatever `seq_utils.pad_sequences`
          returns for that list.
        """
        if min_seq_len is None:
            min_seq_len = self.min_length
        if max_seq_len is None:
            max_seq_len = self.length
        random_state = utils.get_random_state(seed)
        # Exclude EOS by *value*. `np.delete` removes by *index*, which only
        # matches when `token_ids[i] == i`. A boolean mask is correct for any
        # id layout and keeps the remaining ids in their original order, so
        # the sampled sequences are unchanged in the index-aligned case.
        token_ids = np.asarray(self.vocab.token_ids)
        valid_token_ids = token_ids[token_ids != self.vocab.eos]
        # `max_seq_len + 1` because the upper bound of `randint` is exclusive.
        lengths = random_state.randint(min_seq_len, max_seq_len + 1,
                                       num_samples)
        seqs = [
            random_state.choice(valid_token_ids, length) for length in lengths
        ]
        if pad:
            seqs = seq_utils.pad_sequences(seqs, self.length, self.vocab.eos)
        return seqs
Ejemplo n.º 2
0
 def sample_uniformly(self, num_samples, seed=None):
     """Draws sequences whose every position is sampled uniformly.

     Every id in `[0, self.vocab_size)` is a candidate at every position. No
     token is excluded, and there is no variable length or padding.

     Args:
       num_samples: Number of sequences to draw.
       seed: Optional seed for the random number generator.

     Returns:
       The `randint` draw of size `[num_samples, self.length]`, cast with
       `np.int32`. This assumes `utils.get_random_state` returns a
       `np.random.RandomState`-like object; TODO confirm.
     """
     rng = utils.get_random_state(seed)
     shape = [num_samples, self.length]
     # `high` is exclusive, so ids span 0 .. vocab_size - 1.
     draws = rng.randint(size=shape, low=0, high=self.vocab_size)
     return np.int32(draws)