# Constructor of the agent's world model (a torch.nn.Module subclass):
# wires up the image encoder/decoder, the RSSM (transition, representation,
# rollout), and the action/reward/value heads.
def __init__(
        self,
        action_shape,
        stochastic_size=30,
        deterministic_size=200,
        hidden_size=200,
        image_shape=(3, 64, 64),
        action_hidden_size=200,
        action_layers=3,
        action_dist='one_hot',
        reward_shape=(1,),
        reward_layers=3,
        reward_hidden=300,
        value_shape=(1,),
        value_layers=3,
        value_hidden=200,
        dtype=torch.float,
        use_pcont=False,
        pcont_layers=3,
        pcont_hidden=200,
        **kwargs,
):
    super().__init__()
    # Convolutional encoder over image observations; the decoder reconstructs
    # images from the concatenated stochastic + deterministic latent state.
    self.observation_encoder = ObservationEncoder(shape=image_shape)
    encoder_embed_size = self.observation_encoder.embed_size
    decoder_embed_size = stochastic_size + deterministic_size
    self.observation_decoder = ObservationDecoder(
        embed_size=decoder_embed_size, shape=image_shape)
    self.action_shape = action_shape
    output_size = np.prod(action_shape)
    # RSSM core: prior transition model, posterior representation model, and
    # a rollout helper that chains them over time.
    self.transition = RSSMTransition(output_size, stochastic_size,
                                     deterministic_size, hidden_size)
    self.representation = RSSMRepresentation(self.transition,
                                             encoder_embed_size, output_size,
                                             stochastic_size,
                                             deterministic_size, hidden_size)
    self.rollout = RSSMRollout(self.representation, self.transition)
    # All heads consume the concatenated stochastic + deterministic features.
    feature_size = stochastic_size + deterministic_size
    self.action_size = output_size
    self.action_dist = action_dist
    self.action_decoder = ActionDecoder(output_size, feature_size,
                                        action_hidden_size, action_layers,
                                        action_dist)
    self.reward_model = DenseModel(feature_size, reward_shape, reward_layers,
                                   reward_hidden)
    self.value_model = DenseModel(feature_size, value_shape, value_layers,
                                  value_hidden)
    self.dtype = dtype
    self.stochastic_size = stochastic_size
    self.deterministic_size = deterministic_size
    if use_pcont:
        # Optional discount/continuation predictor for learning episode ends.
        self.pcont = DenseModel(feature_size, (1,), pcont_layers,
                                pcont_hidden, dist='binary')
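# --- Usage sketch (addition, not from the original source) ---
# A minimal construction example showing how the pieces above fit together.
# The enclosing class name `AgentModel` is an assumption; substitute the
# actual class defined in this module. The heads are assumed to return
# torch.distributions objects, as the observation decoder does in the tests
# below.
def _example_agent_model_usage():
    model = AgentModel(action_shape=(6,))  # e.g. six one-hot discrete actions
    # Heads consume the concatenated stochastic + deterministic state.
    feature_size = model.stochastic_size + model.deterministic_size  # 30 + 200
    features = torch.randn(2, feature_size)  # fake latent features, batch of 2
    with torch.no_grad():
        reward_dist = model.reward_model(features)
        value_dist = model.value_model(features)
        action_dist = model.action_decoder(features)
    return reward_dist, value_dist, action_dist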
import torch

# ObservationEncoder and ObservationDecoder come from this package's model
# definitions (the same classes used by the constructor above).


def test_observation(shape):
    """Round-trip the encoder/decoder in both directions and check shapes."""
    batch_size = 2
    c, h, w = shape
    encoder = ObservationEncoder(shape=shape)
    decoder = ObservationDecoder(embed_size=encoder.embed_size, shape=shape)

    # Image -> embedding -> image distribution: samples keep the image shape.
    image_obs = torch.randn(batch_size, c, h, w)
    with torch.no_grad():
        obs_dist: torch.distributions.Normal = decoder(encoder(image_obs))
        obs_sample: torch.Tensor = obs_dist.sample()
    assert obs_sample.size(0) == batch_size
    assert obs_sample.size(1) == c
    assert obs_sample.size(2) == h
    assert obs_sample.size(3) == w

    # Embedding -> decoded sample -> embedding: the embed size is preserved.
    embedding = torch.randn(batch_size, encoder.embed_size)
    with torch.no_grad():
        embedding: torch.Tensor = encoder(decoder(embedding).sample())
    assert embedding.size(0) == batch_size
    assert embedding.size(1) == encoder.embed_size
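# --- Parametrized driver (addition, not from the original source) ---
# A sketch of exercising the round-trip test over several image shapes with
# pytest. The alternate shape only changes the channel count, which the
# convolutional stacks are assumed to accept; the 64x64 spatial size is the
# one the default architecture is built for.
import pytest


@pytest.mark.parametrize('image_shape', [(3, 64, 64), (1, 64, 64)])
def test_observation_round_trip(image_shape):
    test_observation(image_shape)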
def test_observation_decoder(shape=(3, 64, 64)):
    """Check decoder output shapes for one and two leading batch dimensions."""
    # Pass `shape` through so the test honours its parameter; the embedding
    # size of 1024 matches the decoder's default embed_size.
    decoder = ObservationDecoder(shape=shape)
    batch_size = 2
    c, h, w = shape
    embedding = torch.randn(batch_size, 1024)
    with torch.no_grad():
        obs_dist: torch.distributions.Normal = decoder(embedding)
        obs_sample: torch.Tensor = obs_dist.sample()
    assert obs_sample.size(0) == batch_size
    assert obs_sample.size(1) == c
    assert obs_sample.size(2) == h
    assert obs_sample.size(3) == w

    # Test a version where we have 2 batch dimensions.
    horizon = 4
    embedding = torch.randn(batch_size, horizon, 1024)
    with torch.no_grad():
        obs_dist: torch.distributions.Normal = decoder(embedding)
        obs_sample: torch.Tensor = obs_dist.sample()
    assert obs_sample.size(0) == batch_size
    assert obs_sample.size(1) == horizon
    assert obs_sample.size(2) == c
    assert obs_sample.size(3) == h
    assert obs_sample.size(4) == w
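# --- Leading-dimension sketch (addition, not from the original source) ---
# Why the two-batch-dimension case above matters: imagined rollouts produce
# features shaped (batch, horizon, embed), and the decoder must treat every
# leading dimension as a batch dimension. The batch of 5 and horizon of 7
# used here are arbitrary choices.
def test_observation_decoder_leading_dims(shape=(3, 64, 64)):
    decoder = ObservationDecoder(shape=shape)
    embedding = torch.randn(5, 7, 1024)
    with torch.no_grad():
        obs_sample = decoder(embedding).sample()
    assert obs_sample.shape == (5, 7) + shape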