Exemple #1
0
  def test_needs_probabilities(self):
    """Checks when `needs_probabilities()` asks the caller for probabilities.

    Three sequences are sampled:
    1. A single choice between two indices.
    2. The other index, followed by a boolean choice.
    3. The remaining path.

    The assertions check two things. First, probabilities are requested at
    newly reached positions and not at positions whose probabilities were
    already supplied. Second, the randomizer is exhausted at the end.
    """
    randomizer = ur.UniqueRandomizer()

    # Sequence 1: nothing is known yet, so probabilities are needed.
    self.assertTrue(randomizer.needs_probabilities())
    first_index = randomizer.sample_distribution([0.9, 0.1])
    self.assertTrue(randomizer.needs_probabilities())
    randomizer.mark_sequence_complete()

    # Sequence 2: the first choice's probabilities were supplied in sequence 1,
    # so None is passed instead of a distribution.
    self.assertFalse(randomizer.needs_probabilities())
    second_index = randomizer.sample_distribution(None)
    # This position is newly reached, so probabilities are needed again.
    self.assertTrue(randomizer.needs_probabilities())
    randomizer.sample_boolean(probability_1=0.123)
    self.assertTrue(randomizer.needs_probabilities())
    randomizer.mark_sequence_complete()

    # The second sequence must not repeat the first sequence's first choice.
    self.assertNotEqual(first_index, second_index)

    # Sequence 3: both choices on the remaining path already have probabilities.
    self.assertFalse(randomizer.needs_probabilities())
    randomizer.sample_distribution(None)
    self.assertFalse(randomizer.needs_probabilities())
    # 999 is not a valid probability. Presumably it is ignored because the
    # randomizer just reported that it does not need probabilities here.
    randomizer.sample_boolean(probability_1=999)
    self.assertTrue(randomizer.needs_probabilities())
    randomizer.mark_sequence_complete()

    # Every sequence has now been sampled.
    self.assertTrue(randomizer.exhausted())
Exemple #2
0
  def test_proportions(self):
    """Checks that sampling without replacement yields correct orderings.

    Each trial does the following:
    - Draws the digits '0', '1' and '2' without replacement until the
      randomizer is exhausted.
    - Records the resulting 3-character ordering.

    The frequency of each ordering over 10000 trials is compared with its
    exact probability. It's possible but extremely unlikely for this test to
    fail.
    """
    num_trials = 10000

    def draw_ordering():
      # One trial: a fresh randomizer, sampled until every digit is drawn.
      randomizer = ur.UniqueRandomizer()
      drawn = []
      while not randomizer.exhausted():
        # '2' has probability 0.6. Otherwise a fair choice picks '0' or '1',
        # so each of those has probability 0.2.
        is_two = randomizer.sample_boolean(probability_1=0.6)
        drawn.append('2' if is_two else str(randomizer.sample_uniform(2)))
        randomizer.mark_sequence_complete()
      return ''.join(drawn)

    orderings = [draw_ordering() for _ in range(num_trials)]

    self.assertTrue(all(len(s) == 3 for s in orderings))
    counter = collections.Counter(orderings)

    # (ordering, exact probability, allowed deviation in counts)
    expected = (
        # P('201') = P('210') = 0.6 * 0.5.
        ('201', 0.6 * 0.5, 250),
        ('210', 0.6 * 0.5, 250),
        # P('021') = P('120') = 0.2 * 0.75.
        ('021', 0.2 * 0.75, 200),
        ('120', 0.2 * 0.75, 200),
        # P('012') = P('102') = 0.2 * 0.25.
        ('012', 0.2 * 0.25, 100),
        ('102', 0.2 * 0.25, 100),
    )
    for ordering, probability, delta in expected:
      self.assertAlmostEqual(
          counter[ordering], probability * num_trials, delta=delta)
def sample_with_ur(
        num_unique_samples: int,
        nonterminal_name: Text,
        grammar_dict: Dict[Text, Nonterminal],
        prob_sleep_millis: float = 0.0) -> Tuple[Dict[Text, float], int]:
    """Draws distinct expansions of a nonterminal using a UniqueRandomizer.

    This is a thin wrapper: all the work is delegated to `_sample_expansions`,
    which receives a freshly constructed `ur.UniqueRandomizer`.

    Args:
      num_unique_samples: Forwarded to `_sample_expansions`; presumably the
        number of distinct expansions to draw.
      nonterminal_name: Name of the nonterminal whose expansions are sampled.
      grammar_dict: Mapping from names to `Nonterminal` definitions.
      prob_sleep_millis: Forwarded unchanged to `_sample_expansions`.
        Presumably an artificial per-probability delay in milliseconds;
        confirm against that helper.

    Returns:
      The result of `_sample_expansions`, annotated as a pair of a
      `Dict[Text, float]` and an `int`.
    """
    unique_randomizer = ur.UniqueRandomizer()
    return _sample_expansions(num_unique_samples, nonterminal_name,
                              grammar_dict, unique_randomizer,
                              prob_sleep_millis)
Exemple #4
0
def sample_flips_without_replacement() -> None:
  """Draws every outcome of the two weighted coin flips, printing each one.

  A UniqueRandomizer is used, so outcomes are sampled without replacement.
  For each sample, one line is printed with:
  - the sample's index and value,
  - its probability,
  - the fraction of the output space covered so far.
  """
  message_template = ('Sample {} is {} with probability {:2.0f}%. '
                      'In total, {:3.0f}% of the output space has been sampled.')
  randomizer = ur.UniqueRandomizer()

  # Keep drawing until every possible pair of flips has been produced.
  while not randomizer.exhausted():
    flips = flip_two_weighted_coins(randomizer)
    log_probability = randomizer.mark_sequence_complete()
    sample_number = randomizer.num_sequences_sampled()
    coverage_percent = randomizer.fraction_sampled() * 100
    print(message_template.format(sample_number,
                                  flips,
                                  math.exp(log_probability) * 100,
                                  coverage_percent))
Exemple #5
0
  def test_root_is_leaf_edge_case(self):
    """Checks completing a sequence before any sampling choice is made.

    With no choices, the empty sequence is the only possible sequence. One
    completion therefore covers the whole output space.
    """
    randomizer = ur.UniqueRandomizer()

    # Nothing has been sampled yet.
    self.assertEqual(randomizer.fraction_sampled(), 0.0)
    self.assertFalse(randomizer.exhausted())
    self.assertEqual(randomizer.num_sequences_sampled(), 0)

    # Complete a sequence without calling any sample_* method.
    log_probability = randomizer.mark_sequence_complete()

    # That sequence had probability 1 (log probability 0). It was the only
    # sequence, so the randomizer is now fully sampled and exhausted.
    self.assertEqual(log_probability, math.log(1.0))
    self.assertEqual(randomizer.fraction_sampled(), 1.0)
    self.assertTrue(randomizer.exhausted())
    self.assertEqual(randomizer.num_sequences_sampled(), 1)

    # Sampling after exhaustion must raise.
    with self.assertRaises(ur.AllSequencesSampledError):
      randomizer.sample_boolean(0.1)
def farthest_insertion_sampling(nodes, num_samples, unique_samples,
                                temperature, caching=True):
  """Samples tours with the farthest-insertion heuristic and keeps the best.

  The first tour is built deterministically (no randomizer). Up to
  `num_samples - 1` further tours are then sampled through a randomizer. When
  `unique_samples` is True, sampling stops early once the UniqueRandomizer has
  exhausted its sample space.

  Args:
    nodes: Nodes to tour; passed through to `farthest_insertion`.
    num_samples: Maximum number of tours to consider, including the
      deterministic first tour.
    unique_samples: If True, sample with `ur.UniqueRandomizer`; otherwise use
      `ur.NormalRandomizer`.
    temperature: Passed to `farthest_insertion` for the sampled tours.
    caching: Passed to `farthest_insertion` for the sampled tours.

  Returns:
    A `(min_cost, best_tour)` pair for the cheapest tour found.
  """
  min_cost, best_tour = farthest_insertion(nodes, randomizer=None)

  randomizer = (ur.UniqueRandomizer() if unique_samples
                else ur.NormalRandomizer())
  for _ in range(1, num_samples):
    # A UniqueRandomizer raises AllSequencesSampledError when asked to sample
    # after every sequence has been drawn. Stop once the space is used up
    # instead of crashing on small inputs.
    if unique_samples and randomizer.exhausted():
      break
    cost, tour = farthest_insertion(nodes, randomizer, temperature,
                                    caching=caching)
    randomizer.mark_sequence_complete()
    if cost < min_cost:
      min_cost = cost
      best_tour = tour

  return min_cost, best_tour
def sample_with_unique_randomizer_batched(model: FakeSequenceModel,
                                          num_samples: int,
                                          batch_size: int) -> List[List[int]]:
    """Draws unique samples in batches via UniqueRandomizer's SBS support.

    Batches are requested through `randomizer.sample_batch`, each asking for
    at most `batch_size` items and never more than are still needed.
    Collection stops when `num_samples` outputs have been gathered or the
    randomizer is exhausted, whichever happens first.

    Returns:
      The `output` of every beam node returned, in the order received.
    """
    randomizer = ur.UniqueRandomizer()
    # The model-bound callbacks do not change between batches.
    log_prob_fn = functools.partial(_child_log_probability_fn, model=model)
    state_fn = functools.partial(_child_state_fn, model=model)

    collected = []
    while not randomizer.exhausted() and len(collected) < num_samples:
        still_needed = num_samples - len(collected)
        batch = randomizer.sample_batch(
            child_log_probability_fn=log_prob_fn,
            child_state_fn=state_fn,
            root_state=[],
            k=min(batch_size, still_needed))
        collected += [beam_node.output for beam_node in batch]
    return collected
def sample_with_unique_randomizer(model: FakeSequenceModel,
                                  num_samples: int) -> List[List[int]]:
    """Draws up to `num_samples` distinct token sequences from `model`.

    Each sequence is built one token at a time: the model's next-token log
    probabilities are exponentiated and fed to the UniqueRandomizer. Sampling
    stops early if the randomizer runs out of unseen sequences.
    """
    randomizer = ur.UniqueRandomizer()

    def draw_one() -> List[int]:
        # Extend the sequence until the model reports that it is complete.
        tokens = []
        while not model.sequence_complete(tokens):
            next_probs = np.exp(model.next_token_log_probabilities(tokens))
            tokens.append(randomizer.sample_distribution(next_probs))
        randomizer.mark_sequence_complete()
        return tokens

    drawn = []
    while len(drawn) < num_samples and not randomizer.exhausted():
        # The UniqueRandomizer guarantees this sequence has not been drawn.
        drawn.append(draw_one())
    return drawn
Exemple #9
0
  def test_ur_sbs_proportions(self, k):
    """Checks that `sample_batch` reproduces test_proportions' distribution.

    This test is analogous to test_proportions. Each trial collects all three
    outputs via `sample_batch` with beam size `k`, in sampled order. The
    frequency of each ordering over 10000 trials is compared with its exact
    probability. It's possible but extremely unlikely for this test to fail.
    """
    # A state is a pair representing two coin flips:
    # - The first flip is biased (60% True). If True, the output is '2'.
    # - If False, a second, fair flip decides whether the output is '0' or '1'.
    # See unique_randomizer_test.py's test_proportions for a procedural
    # representation of this logic.

    def log_probs_for(state):
      first_flip, _ = state
      if first_flip is None:
        return np.log([0.4, 0.6])
      if not first_flip:
        return np.log([0.5, 0.5])
      raise ValueError('Leaf state encountered unexpectedly.')

    def child_for(state, index):
      # Returns (child state or output, whether the child is a leaf).
      first_flip, _ = state
      if first_flip is None:
        if index == 0:
          return ((False, None), False)
        if index == 1:
          return ('2', True)
        raise ValueError('Out of bounds index: {}'.format(index))
      if not first_flip:
        return (str(index), True)
      raise ValueError('Leaf state encountered unexpectedly.')

    def child_log_probability_fn(states):
      return [log_probs_for(state) for state in states]

    def child_state_fn(state_index_pairs):
      return [child_for(state, index) for state, index in state_index_pairs]

    num_trials = 10000
    orderings = []
    for _ in range(num_trials):
      randomizer = ur.UniqueRandomizer()
      outputs = []
      while not randomizer.exhausted():
        beam_nodes = randomizer.sample_batch(
            child_log_probability_fn=child_log_probability_fn,
            child_state_fn=child_state_fn,
            root_state=(None, None),
            k=k)
        outputs.extend(node.output for node in beam_nodes)
      orderings.append(''.join(outputs))

    self.assertTrue(all(len(s) == 3 for s in orderings))
    counter = collections.Counter(orderings)

    # (ordering, exact probability, allowed deviation in counts)
    expected = (
        # P('201') = P('210') = 0.6 * 0.5.
        ('201', 0.6 * 0.5, 250),
        ('210', 0.6 * 0.5, 250),
        # P('021') = P('120') = 0.2 * 0.75.
        ('021', 0.2 * 0.75, 200),
        ('120', 0.2 * 0.75, 200),
        # P('012') = P('102') = 0.2 * 0.25.
        ('012', 0.2 * 0.25, 100),
        ('102', 0.2 * 0.25, 100),
    )
    for ordering, probability, delta in expected:
      self.assertAlmostEqual(
          counter[ordering], probability * num_trials, delta=delta)