Example #1
def test_logging():

    import os

    import policies
    from main import (
        TemplateSimulatorSession,
        default_config,
    )

    sim = TemplateSimulatorSession(render=False,
                                   log_data=True,
                                   log_file_name="tmp.csv")
    for episode in range(2):
        iteration = 0
        terminal = False
        sim_state = sim.episode_start(config=default_config)
        while not terminal:
            action = policies.random_policy(sim_state)
            sim.episode_step(action)
            sim_state = sim.get_state()
            print(f"Running iteration #{iteration} for episode #{episode}")
            print(f"Observations: {sim_state}")
            sim.log_iterations(sim_state, action, episode, iteration)
            iteration += 1
            terminal = iteration >= 10
    assert sim.render is False
    assert os.path.exists(sim.log_full_path)
    os.remove("logs/tmp.csv")
Example #2
def test_pole_displacement(sim):
    """The larger (heavier) configuration should displace the cart less
    than the default configuration under a random action."""

    sim_state = sim.get_state()
    random_action = policies.random_policy(sim_state)

    sim.episode_step(random_action)
    next_state = sim.get_state()

    default_displacement = next_state["x_position"] - sim_state["x_position"]

    sim.episode_start(large_config)
    sim_state = sim.get_state()
    random_action = policies.random_policy(sim_state)

    sim.episode_step(random_action)
    next_state = sim.get_state()

    smaller_displacement = next_state["x_position"] - sim_state["x_position"]

    assert abs(smaller_displacement) < abs(default_displacement)
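Examples 2-4 rely on a pytest fixture named sim and a module-level large_config, both defined in the surrounding test module. A sketch of how they might look, assuming the same TemplateSimulatorSession and default_config as in Example #1 and a hypothetical heavier/longer variant of the config (the exact keys are assumptions):

import pytest

from main import TemplateSimulatorSession, default_config

# Hypothetical heavier/longer pole configuration; the key names are assumed.
large_config = {**default_config, "masspole": 1.0, "length": 1.0}


@pytest.fixture
def sim():
    """Fresh simulator session started with the default configuration."""
    session = TemplateSimulatorSession(render=False)
    session.episode_start(config=default_config)
    return session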
Example #3
def test_random_action(sim):
    """A random action can always be produced from the current state and
    applied to the simulator to obtain a new state."""

    sim_state = sim.get_state()
    assert sim_state is not None

    random_action = policies.random_policy(sim_state)
    assert random_action is not None

    sim.episode_step(random_action)

    next_sim_state = sim.get_state()
    assert next_sim_state is not None
Example #4
def test_physics(sim):
    """
    Same force should change velocity of heavier cartpole less
    """
    sim_state = sim.get_state()
    random_action = policies.random_policy(sim_state)

    sim.episode_step(random_action)
    next_state = sim.get_state()
    print(f"sim_state: {sim_state}; next_state: {next_state}")
    default_delta_v = next_state["cart_velocity"] - sim_state["cart_velocity"]

    sim.episode_start(large_config)
    sim_state = sim.get_state()
    # use the same action as above
    sim.episode_step(random_action)
    next_state = sim.get_state()
    print(f"sim_state: {sim_state}; next_state: {next_state}")

    smaller_delta_v = next_state["cart_velocity"] - sim_state["cart_velocity"]

    assert abs(smaller_delta_v) < abs(default_delta_v)
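The assertion follows from Newton's second law: over one time step of length dt, delta_v is approximately F * dt / m, so the same force applied to a larger total mass yields a smaller velocity change. A tiny numeric sketch (the force, step size, and masses below are made-up values, not the simulator's):

def delta_v(force: float, mass: float, dt: float = 0.02) -> float:
    """Velocity change of a point mass under a constant force for one step."""
    return force * dt / mass


default_delta_v = delta_v(force=10.0, mass=1.1)   # default cart + pole mass (assumed)
smaller_delta_v = delta_v(force=10.0, mass=2.0)   # heavier configuration (assumed)
assert abs(smaller_delta_v) < abs(default_delta_v)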
Example #5
    def random_policy(self, state: Dict = None) -> Dict:
        """Thin wrapper that delegates to the module-level random_policy."""
        return random_policy(state)
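The wrapper lets a policy object expose the same call signature as other policies while reusing the module-level function. A minimal, self-contained sketch of the pattern; the class name and the action key are illustrative, not from the original module:

import random
from typing import Dict


def random_policy(state: Dict = None) -> Dict:
    """Illustrative module-level policy (see the sketch after Example #1)."""
    return {"command": random.uniform(-1, 1)}


class RandomAgent:
    def random_policy(self, state: Dict = None) -> Dict:
        """Delegate to the module-level random_policy function."""
        return random_policy(state)


# Usage: RandomAgent().random_policy(state={}) -> {"command": ...}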
Example #6
def test_policies():
  """Tests for the policies."""
  result = True
  actions = {
      0:
          gt.Action(
              None,
              None,
              action=0,
              visit_count=9,
              average_value=0.1,
              prior=0.2,
              for_maximizer=True),
      1:
          gt.Action(
              None,
              None,
              action=1,
              visit_count=16,
              average_value=0.9,
              prior=0.1,
              for_maximizer=True),
      8:
          gt.Action(
              None,
              None,
              action=8,
              visit_count=25,
              average_value=-0.8,
              prior=0.4,
              for_maximizer=True),
      7:
          gt.Action(
              None,
              None,
              action=7,
              visit_count=50,
              average_value=0.0,
              prior=0.3,
              for_maximizer=True)
  }
  np.testing.assert_almost_equal(
      po.alpha_zero_visit_counts_to_target(actions, 9, 1.0),
      np.array([0.09, 0.16, -1.0, -1.0, -1.0, -1.0, -1.0, 0.50, 0.25]))
  np.testing.assert_almost_equal(
      po.average_values_to_target(actions, 9),
      np.array([1.1, 1.9, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 0.2]))
  result = result and expect_equal(
      po.greedy_value_policy(actions), actions[1], 'greedy value')
  result = result and expect_equal(
      po.greedy_visit_policy(actions), actions[7], 'greedy visit')
  result = result and expect_equal(
      po.greedy_prior_policy(actions), actions[8], 'greedy prior')
  result = result and expect_equal(
      po.alpha_zero_mcts_policy(actions, c_puct=1.0), actions[1],
      'alpha zero mcts')
  for action in actions.values():
    action.for_maximizer = False
  result = result and expect_equal(
      po.greedy_value_policy(actions), actions[8], 'greedy value minimizer')
  result = result and expect_equal(
      po.alpha_zero_mcts_policy(actions, c_puct=1.0), actions[8],
      'alpha zero mcts minimizer')
  np.testing.assert_almost_equal(
      po.average_values_to_target(actions, 9),
      np.array([0.9, 0.1, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.8]))
  # Smoke test for the randomized policies.
  po.random_policy(actions)
  po.alpha_zero_play_policy(actions, tau=1.0)
  test_result(result, 'Policies Test')
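The expected array in the first assertion is each action's visit count divided by the total (9 + 16 + 25 + 50 = 100), written at that action's index, with unvisited actions marked -1.0. A sketch of that normalization, assuming this reading of alpha_zero_visit_counts_to_target rather than its actual implementation:

import numpy as np


def visit_counts_to_target(actions, num_actions, tau=1.0):
  """Sketch: turn visit counts into a policy target; unvisited slots get -1."""
  target = np.full(num_actions, -1.0)
  weights = {i: a.visit_count ** (1.0 / tau) for i, a in actions.items()}
  total = sum(weights.values())
  for i, w in weights.items():
    target[i] = w / total
  return target


# With visit counts 9, 16, 25, 50 at actions 0, 1, 8, 7 and tau=1.0 this
# yields 0.09, 0.16, 0.25, 0.50 in those slots, matching the assertion.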
Example #7
  def next_batch(self):
    """Return the next minibatch of augmented data."""
    next_train_index = self.curr_train_index + self.hparams.batch_size
    if next_train_index > self.num_train:
      # Increase epoch number
      epoch = self.epochs + 1
      self.reset()
      self.epochs = epoch
    batched_data = (
        self.train_images[self.curr_train_index:self.curr_train_index +
                          self.hparams.batch_size],
        self.train_labels[self.curr_train_index:self.curr_train_index +
                          self.hparams.batch_size])
    final_imgs = []
    images, labels = batched_data
    if self.hparams.augment_type == 'mixup':
      images, labels = augmentation_transforms.mixup_batch(
          images, labels, self.hparams.mixup_alpha)
    elif self.hparams.augment_type == 'image_freq':
      images, labels = augmentation_transforms.freq_augment(
          images,
          labels,
          amplitude=self.hparams.freq_augment_amplitude,
          magnitude=self.hparams.augmentation_magnitude,
          proportion_f=self.hparams.freq_augment_ffrac,
          probability=self.hparams.augmentation_probability)
    for data in images:
      if self.hparams.augment_type == 'autoaugment':
        epoch_policy = self.good_policies[np.random.choice(
            len(self.good_policies))]
        final_img = augmentation_transforms.apply_policy(epoch_policy, data)
      elif self.hparams.augment_type == 'random':
        epoch_policy = found_policies.random_policy(
            self.hparams.num_augmentation_layers,
            self.hparams.augmentation_magnitude,
            self.hparams.augmentation_probability)
        final_img = augmentation_transforms.apply_policy(epoch_policy, data)
      else:
        final_img = np.copy(data)
      if self.hparams.apply_flip_crop:
        final_img = augmentation_transforms.random_flip(
            augmentation_transforms.zero_pad_and_crop(data, 4))
      # Apply cutout
      if self.hparams.apply_cutout:
        final_img = augmentation_transforms.cutout_numpy(final_img)

      final_imgs.append(final_img)
    final_imgs = np.array(final_imgs, np.float32)
    if self.hparams.noise_type == 'radial':
      labels = augmentation_transforms.add_radial_noise(
          final_imgs, labels, self.hparams.frequency, self.hparams.amplitude,
          self.hparams.noise_class, self.hparams.normalize_amplitude)
    elif self.hparams.noise_type in ('random', 'fourier', 'f', '1/f'):
      labels = augmentation_transforms.add_sinusoidal_noise(
          final_imgs, labels, self.hparams.frequency, self.hparams.amplitude,
          self.direction, self.hparams.noise_class,
          self.hparams.normalize_amplitude)
    elif self.hparams.noise_type == 'uniform':
      labels = augmentation_transforms.add_uniform_noise(
          labels, self.hparams.amplitude, self.hparams.noise_class)

    batched_data = (final_imgs, labels)
    self.curr_train_index += self.hparams.batch_size
    return batched_data
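The bookkeeping at the top of next_batch (advance the index, roll over to a new epoch when the next slice would run past the end) can be illustrated with a toy, self-contained version of the same pattern; the array contents and sizes below are arbitrary, and the real loader's reset() would typically also reshuffle the data:

import numpy as np

num_train, batch_size = 10, 4
train_images = np.arange(num_train)
curr_train_index, epochs = 0, 0

for _ in range(5):
  if curr_train_index + batch_size > num_train:
    epochs += 1            # finished an epoch
    curr_train_index = 0   # stand-in for reset()
  batch = train_images[curr_train_index:curr_train_index + batch_size]
  curr_train_index += batch_size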