예제 #1
0
 def testOneHotRound(self):
     """Exercise both one-hot rounding helpers.

     Strict rounding of ``[0.1, 0.5, 0.3]`` is expected to give ``[0, 1, 0]``.
     Randomized rounding of an all-zero vector must still switch on exactly
     one entry.
     """
     strict_result = strict_onehot_round(np.array([0.1, 0.5, 0.3]))
     strict_expected = np.array([0, 1, 0])
     self.assertTrue(np.allclose(strict_result, strict_expected))

     # One item should be set to one at random: with all-zero input there is
     # no preferred entry, but the output must still contain a single one.
     random_result = randomized_onehot_round(np.array([0.0, 0.0, 0.0]))
     is_one = np.isclose(random_result, np.array([1, 1, 1]))
     self.assertEqual(np.count_nonzero(is_one), 1)
예제 #2
0
 def untransform_observation_features(
     self, observation_features: List[ObservationFeatures]
 ) -> List[ObservationFeatures]:
     """Collapse one-hot encoded columns back into their original parameters.

     For every parameter name ``p_name`` in ``self.encoder``, the encoded
     columns listed in ``self.encoded_parameters[p_name]`` are popped from
     each observation's ``parameters`` mapping. They are then rounded to a
     one-hot vector: with ``strict_onehot_round`` when ``self.rounding`` is
     ``"strict"``, and with ``randomized_onehot_round`` otherwise. The vector
     is decoded with the encoder's ``inverse_transform``, and the decoded
     value is stored back under ``p_name``.

     Args:
         observation_features: Observation features to untransform. Each
             element's ``parameters`` is modified in place.

     Returns:
         The same list that was passed in, after in-place modification.

     Raises:
         KeyError: If an observation lacks one of the encoded columns. This
             assumes ``parameters`` is a ``dict``; TODO confirm. Any columns
             preceding the missing one have already been popped by then.
     """
     for obsf in observation_features:
         for p_name in self.encoder.keys():
             x = np.array([
                 obsf.parameters.pop(p)
                 for p in self.encoded_parameters[p_name]
             ])
             if self.rounding == "strict":
                 x = strict_onehot_round(x)
             else:
                 x = randomized_onehot_round(x)
             # x[None, :] wraps the single one-hot vector as a batch of one;
             # [0] takes that row's decoded value back out.
             val = self.encoder[p_name].inverse_transform(
                 encoded_labels=x[None, :])[0]
             # NumPy scalars (np.bool_, np.int64, ...) don't serialize, so
             # convert any of them to the equivalent native Python value.
             if isinstance(val, np.generic):
                 val = val.item()
             obsf.parameters[p_name] = val
     return observation_features