def test_type(self):
    """Check that `math.lognormexp` returns the same container type it is given."""
    # A torch tensor goes in, a torch tensor must come out.
    tensor_result = math.lognormexp(torch.rand(1))
    self.assertIsInstance(tensor_result, torch.Tensor)

    # A numpy array goes in, a numpy array must come out.
    array_result = math.lognormexp(np.array([2]))
    self.assertIsInstance(array_result, np.ndarray)
def test_value(self):
    """Compare `math.lognormexp` on [1, 2, 3] against a directly computed reference.

    The reference is log(exp(x) / sum(exp(x))), evaluated with numpy. Both the
    torch and the numpy code paths are checked with an absolute tolerance of 1e-6.
    """
    values = [1, 2, 3]
    # Sum of exponentials, accumulated left to right.
    normalizer = sum(np.exp(v) for v in values)
    expected = np.log(np.exp(values) / normalizer)

    # torch path: both sides are routed through torch.Tensor before comparison.
    torch_actual = math.lognormexp(torch.Tensor(values)).numpy()
    torch_expected = torch.Tensor(expected).numpy()
    np.testing.assert_allclose(torch_actual, torch_expected, atol=1e-6)

    # numpy path
    numpy_actual = math.lognormexp(np.array(values))
    np.testing.assert_allclose(numpy_actual, np.array(expected), atol=1e-6)
def encode_particles(self, latent_state):
    """
    Summarise the particle set in `latent_state` using `self.particle_gru`.

    Each particle's features are `h` and `phi_z` joined along dim 2, followed
    by one extra trailing entry: the exponential of the particle's log weight
    after `math.lognormexp` has been applied over dim 1.

    Args:
        latent_state: object exposing `h`, `phi_z` and `log_weight`. `h` must
            be 3-dimensional; the original author's note gives its layout as
            [batch, particles, h_dim?].

    Returns:
        When `self.num_particles == 1`: the output of `self.particle_gru`
        applied to the features with dim 1 squeezed out. Otherwise: element 0
        of the second value returned by `self.particle_gru`.
    """
    # The unpacked names are unused; the unpacking itself raises unless `h`
    # has exactly three dimensions.
    _batch, _particles, _hidden = latent_state.h.size()

    features = torch.cat([latent_state.h, latent_state.phi_z], dim=2)
    log_weights = math.lognormexp(latent_state.log_weight, dim=1)
    weights = torch.exp(log_weights).unsqueeze(-1)
    gru_input = torch.cat([features, weights], dim=2)

    if self.num_particles != 1:
        # Original author's note: the second output is laid out as
        # [num_layers * num_directions, batch, h_dim].
        _, hidden = self.particle_gru(gru_input)
        return hidden[0]

    # Single particle: drop the particle dimension. Original author's note:
    # particle_gru is just a nn.Linear in this case.
    return self.particle_gru(gru_input.squeeze(1))
def test_dimensions(self):
    """Check that `math.lognormexp` returns a result shaped like its input."""
    # Each case: (input shape, keyword arguments forwarded to lognormexp).
    cases = (
        ((2, 3, 4, 5), {"dim": 2}),
        ((3,), {}),
        ((1,), {}),
    )

    # torch inputs: compare against torch.Size.
    for shape, kwargs in cases:
        result = math.lognormexp(torch.rand(*shape), **kwargs)
        self.assertEqual(result.size(), torch.Size(shape))

    # numpy inputs: compare the shape as a plain list.
    for shape, kwargs in cases:
        result = math.lognormexp(np.random.rand(*shape), **kwargs)
        self.assertEqual(list(np.shape(result)), list(shape))