Example #1
import numpy as np
import torch
from torch import nn


# Constructor of a Dreamer-style agent model. Only __init__ appears in the
# original snippet; the enclosing nn.Module subclass (named AgentModel here)
# is an assumed reconstruction of the missing context.
class AgentModel(nn.Module):

    def __init__(
        self,
        action_shape,
        stochastic_size=30,
        deterministic_size=200,
        hidden_size=200,
        image_shape=(3, 64, 64),
        action_hidden_size=200,
        action_layers=3,
        action_dist='one_hot',
        reward_shape=(1, ),
        reward_layers=3,
        reward_hidden=300,
        value_shape=(1, ),
        value_layers=3,
        value_hidden=200,
        dtype=torch.float,
        use_pcont=False,
        pcont_layers=3,
        pcont_hidden=200,
        **kwargs,
    ):
        super().__init__()
        # Convolutional encoder/decoder for image observations.
        self.observation_encoder = ObservationEncoder(shape=image_shape)
        encoder_embed_size = self.observation_encoder.embed_size
        decoder_embed_size = stochastic_size + deterministic_size
        self.observation_decoder = ObservationDecoder(
            embed_size=decoder_embed_size, shape=image_shape)
        self.action_shape = action_shape
        output_size = np.prod(action_shape)
        # Recurrent state-space model (RSSM): prior transition, posterior
        # representation, and a rollout helper that chains them over time.
        self.transition = RSSMTransition(output_size, stochastic_size,
                                         deterministic_size, hidden_size)
        self.representation = RSSMRepresentation(self.transition,
                                                 encoder_embed_size,
                                                 output_size, stochastic_size,
                                                 deterministic_size,
                                                 hidden_size)
        self.rollout = RSSMRollout(self.representation, self.transition)
        # Features fed to the heads are the concatenated stochastic and
        # deterministic parts of the latent state.
        feature_size = stochastic_size + deterministic_size
        self.action_size = output_size
        self.action_dist = action_dist
        self.action_decoder = ActionDecoder(output_size, feature_size,
                                            action_hidden_size, action_layers,
                                            action_dist)
        # Dense heads predicting reward and value from latent features.
        self.reward_model = DenseModel(feature_size, reward_shape,
                                       reward_layers, reward_hidden)
        self.value_model = DenseModel(feature_size, value_shape, value_layers,
                                      value_hidden)
        self.dtype = dtype
        self.stochastic_size = stochastic_size
        self.deterministic_size = deterministic_size
        # Optional discount ("pcont") head predicting episode continuation.
        if use_pcont:
            self.pcont = DenseModel(feature_size, (1, ),
                                    pcont_layers,
                                    pcont_hidden,
                                    dist='binary')
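
The reward, value, and pcont heads above all share the `DenseModel(feature_size, output_shape, layers, hidden)` signature. As a rough sketch of what such a head could look like (an assumed stand-in, not the library's actual `DenseModel`): an MLP whose output parameterizes a Normal, or a Bernoulli when `dist='binary'` as in the pcont head.

import numpy as np
import torch
from torch import nn
import torch.distributions as td


class DenseHead(nn.Module):
    # Hypothetical stand-in for DenseModel: an MLP whose output
    # parameterizes a distribution over a tensor of `output_shape`.
    def __init__(self, feature_size, output_shape, layers, hidden,
                 dist='normal'):
        super().__init__()
        self._output_shape = output_shape
        self._dist = dist
        sizes = [feature_size] + [hidden] * layers
        modules = []
        for in_f, out_f in zip(sizes[:-1], sizes[1:]):
            modules += [nn.Linear(in_f, out_f), nn.ELU()]
        modules.append(nn.Linear(sizes[-1], int(np.prod(output_shape))))
        self.mlp = nn.Sequential(*modules)

    def forward(self, features):
        x = self.mlp(features)
        x = x.reshape(*features.shape[:-1], *self._output_shape)
        if self._dist == 'binary':
            # A binary head (like pcont) predicts continuation flags.
            return td.Independent(td.Bernoulli(logits=x),
                                  len(self._output_shape))
        return td.Independent(td.Normal(x, 1.0), len(self._output_shape))


# Usage: a reward head over 230-dim features (30 stochastic + 200
# deterministic), matching the defaults in Example #1.
reward_head = DenseHead(230, (1, ), layers=3, hidden=300)
reward_dist = reward_head(torch.randn(2, 230))
print(reward_dist.mean.shape)  # torch.Size([2, 1])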
Example #2
import torch


# ObservationEncoder / ObservationDecoder come from the library under test.
# `shape` is a (channels, height, width) tuple, e.g. (3, 64, 64); in the
# source test suite it is presumably supplied by pytest parametrization.
def test_observation(shape):
    batch_size = 2
    c, h, w = shape

    encoder = ObservationEncoder(shape=shape)
    decoder = ObservationDecoder(embed_size=encoder.embed_size, shape=shape)

    image_obs = torch.randn(batch_size, c, h, w)

    # Encode then decode: the resulting Normal should be over full images.
    with torch.no_grad():
        obs_dist: torch.distributions.Normal = decoder(encoder(image_obs))
    obs_sample: torch.Tensor = obs_dist.sample()
    assert obs_sample.size(0) == batch_size
    assert obs_sample.size(1) == c
    assert obs_sample.size(2) == h
    assert obs_sample.size(3) == w

    embedding = torch.randn(batch_size, encoder.embed_size)

    # Decode then re-encode: the embedding size should be preserved.
    with torch.no_grad():
        embedding: torch.Tensor = encoder(decoder(embedding).sample())
    assert embedding.size(0) == batch_size
    assert embedding.size(1) == encoder.embed_size
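
The `encoder.embed_size` used above falls out of the conv stack: for (3, 64, 64) inputs, a DreamerV1-style encoder (four 4x4, stride-2 convolutions with widths 32/64/128/256) flattens to 1024 features, consistent with the hard-coded 1024 in Example #3. A minimal sketch under that assumed architecture:

import torch
from torch import nn


class ConvEncoder(nn.Module):
    # Assumed DreamerV1-style encoder: four stride-2 convolutions.
    def __init__(self, shape=(3, 64, 64), depth=32):
        super().__init__()
        c = shape[0]
        self.convs = nn.Sequential(
            nn.Conv2d(c, depth, 4, stride=2), nn.ReLU(),
            nn.Conv2d(depth, 2 * depth, 4, stride=2), nn.ReLU(),
            nn.Conv2d(2 * depth, 4 * depth, 4, stride=2), nn.ReLU(),
            nn.Conv2d(4 * depth, 8 * depth, 4, stride=2), nn.ReLU(),
        )
        # Infer the flattened feature size from a dummy forward pass.
        with torch.no_grad():
            self.embed_size = self.convs(torch.zeros(1, *shape)).numel()

    def forward(self, obs):
        return self.convs(obs).reshape(obs.shape[0], -1)


enc = ConvEncoder()
print(enc.embed_size)  # 1024 for (3, 64, 64): 256 channels x 2 x 2 spatial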
Example #3
import torch


def test_observation_decoder(shape=(3, 64, 64)):
    # Default construction: ObservationDecoder's defaults are assumed to
    # match embed_size=1024 and shape=(3, 64, 64).
    decoder = ObservationDecoder()

    batch_size = 2
    c, h, w = shape
    embedding = torch.randn(batch_size, 1024)
    with torch.no_grad():
        obs_dist: torch.distributions.Normal = decoder(embedding)
    obs_sample: torch.Tensor = obs_dist.sample()
    assert obs_sample.size(0) == batch_size
    assert obs_sample.size(1) == c
    assert obs_sample.size(2) == h
    assert obs_sample.size(3) == w

    # Test a version where we have 2 batch dimensions (batch and horizon):
    # the decoder should pass extra leading dims through unchanged.
    horizon = 4
    embedding = torch.randn(batch_size, horizon, 1024)
    with torch.no_grad():
        obs_dist: torch.distributions.Normal = decoder(embedding)
    obs_sample: torch.Tensor = obs_dist.sample()
    assert obs_sample.size(0) == batch_size
    assert obs_sample.size(1) == horizon
    assert obs_sample.size(2) == c
    assert obs_sample.size(3) == h
    assert obs_sample.size(4) == w

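The two-batch-dimension check above relies on the decoder accepting arbitrary leading dimensions (e.g. batch and imagination horizon). A common way to implement this, shown here as an assumed sketch rather than the library's actual `ObservationDecoder`, is to flatten the leading dims before the transposed convolutions and restore them on the output:

import torch
from torch import nn
import torch.distributions as td


class LeadingDimDecoder(nn.Module):
    # Assumed DreamerV1-style decoder: a dense layer into a 1x1 "image",
    # then four transposed convolutions back up to 64x64.
    def __init__(self, embed_size=1024, shape=(3, 64, 64)):
        super().__init__()
        self.shape = shape
        self.dense = nn.Linear(embed_size, 1024)
        self.deconvs = nn.Sequential(
            nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, shape[0], 6, stride=2),
        )

    def forward(self, embed):
        lead_dims = embed.shape[:-1]            # e.g. (batch, horizon)
        # Collapse all leading dims into one batch dim for the conv stack.
        x = self.dense(embed.reshape(-1, embed.shape[-1]))
        x = self.deconvs(x.reshape(-1, 1024, 1, 1))
        # Restore the leading dims on the decoded mean.
        mean = x.reshape(*lead_dims, *self.shape)
        return td.Independent(td.Normal(mean, 1.0), len(self.shape))


dist = LeadingDimDecoder()(torch.randn(2, 4, 1024))
print(dist.sample().shape)  # torch.Size([2, 4, 3, 64, 64])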