예제 #1
0
def test_rollouts():
    """Check types and shapes returned by every ``RSSMRollout`` entry point.

    Three calls are exercised with random inputs: the full rollout
    (``rollout_module(...)``), ``rollout_transition`` and ``rollout_policy``.
    Because the inputs are random, only result types and tensor shapes are
    asserted, never values.
    """
    action_size = 10
    obs_embed_size = 100
    stochastic_size = 30
    deterministic_size = 200
    batch_size = 4
    time_steps = 10

    transition_model = RSSMTransition(action_size, stochastic_size,
                                      deterministic_size)
    representation_model = RSSMRepresentation(transition_model, obs_embed_size,
                                              action_size, stochastic_size,
                                              deterministic_size)

    rollout_module = RSSMRollout(representation_model, transition_model)

    def assert_state_shapes(state):
        """Assert ``state`` is an ``RSSMState`` with the expected field shapes."""
        stoch_shape = (time_steps, batch_size, stochastic_size)
        deter_shape = (time_steps, batch_size, deterministic_size)
        assert isinstance(state, RSSMState)
        assert state.mean.shape == stoch_shape
        assert state.std.shape == stoch_shape
        assert state.stoch.shape == stoch_shape
        assert state.deter.shape == deter_shape

    obs_embed: torch.Tensor = torch.randn(time_steps, batch_size,
                                          obs_embed_size)
    action: torch.Tensor = torch.randn(time_steps, batch_size, action_size)
    prev_state: RSSMState = representation_model.initial_state(batch_size)

    # Full rollout: returns both a prior and a posterior state sequence.
    prior, post = rollout_module(time_steps, obs_embed, action, prev_state)
    assert_state_shapes(prior)
    assert_state_shapes(post)

    # Transition-only rollout driven by the given actions.
    prior = rollout_module.rollout_transition(
        time_steps, action, transition_model.initial_state(batch_size))
    assert_state_shapes(prior)

    def policy(state):
        """Return a random action and a random action distribution."""
        rows = state.stoch.size(0)
        sampled_action = torch.randn(rows, action_size)
        mean = torch.randn(rows, action_size)
        # Normal requires a strictly positive scale. A raw randn draw is
        # negative about half the time and is rejected when distribution
        # argument validation is enabled, so take its magnitude plus a
        # small floor.
        std = torch.randn(rows, action_size).abs() + 1e-3
        action_dist = SampleDist(torch.distributions.Normal(mean, std))
        return sampled_action, action_dist

    # Policy-driven rollout starting from the last posterior state.
    prior, actions = rollout_module.rollout_policy(time_steps, policy,
                                                   post[-1])
    assert_state_shapes(prior)

    assert isinstance(actions, torch.Tensor)
    assert actions.shape == (time_steps, batch_size, action_size)
예제 #2
0
 def __init__(
     self,
     action_shape,
     stochastic_size=30,
     deterministic_size=200,
     hidden_size=200,
     image_shape=(3, 64, 64),
     action_hidden_size=200,
     action_layers=3,
     action_dist='one_hot',
     reward_shape=(1, ),
     reward_layers=3,
     reward_hidden=300,
     value_shape=(1, ),
     value_layers=3,
     value_hidden=200,
     dtype=torch.float,
     use_pcont=False,
     pcont_layers=3,
     pcont_hidden=200,
     **kwargs,
 ):
     """Construct the encoder/decoder, RSSM, action, reward and value modules.

     Args:
         action_shape: Shape of one action. The product of its entries is
             the flat action size handed to the transition, representation
             and action-decoder modules and stored as ``self.action_size``.
         stochastic_size: Stochastic state size given to the RSSM modules.
         deterministic_size: Deterministic state size given to the RSSM
             modules.
         hidden_size: Hidden size given to ``RSSMTransition`` and
             ``RSSMRepresentation``.
         image_shape: Passed as ``shape`` to the observation encoder and
             decoder.
         action_hidden_size: Hidden size argument of ``ActionDecoder``.
         action_layers: Layer-count argument of ``ActionDecoder``.
         action_dist: Distribution name forwarded to ``ActionDecoder`` and
             stored as ``self.action_dist``.
         reward_shape: Output shape argument of the reward ``DenseModel``.
         reward_layers: Layer-count argument of the reward ``DenseModel``.
         reward_hidden: Hidden size argument of the reward ``DenseModel``.
         value_shape: Output shape argument of the value ``DenseModel``.
         value_layers: Layer-count argument of the value ``DenseModel``.
         value_hidden: Hidden size argument of the value ``DenseModel``.
         dtype: Stored as ``self.dtype``; not otherwise used here.
         use_pcont: When true, an extra ``DenseModel`` with
             ``dist='binary'`` is created as ``self.pcont``. When false the
             attribute is not created at all.
         pcont_layers: Layer-count argument of the ``pcont`` model.
         pcont_hidden: Hidden size argument of the ``pcont`` model.
         **kwargs: Accepted and ignored.
     """
     super().__init__()
     self.observation_encoder = ObservationEncoder(shape=image_shape)
     encoder_embed_size = self.observation_encoder.embed_size
     # Sum of both state sizes: the decoder's embed size and the input size
     # of the action, reward, value and pcont heads.
     feature_size = stochastic_size + deterministic_size
     self.observation_decoder = ObservationDecoder(
         embed_size=feature_size, shape=image_shape)
     self.action_shape = action_shape
     # np.prod returns a NumPy scalar (float64 1.0 for an empty shape);
     # cast so layer sizes and self.action_size are always built-in ints.
     output_size = int(np.prod(action_shape))
     self.transition = RSSMTransition(output_size, stochastic_size,
                                      deterministic_size, hidden_size)
     self.representation = RSSMRepresentation(self.transition,
                                              encoder_embed_size,
                                              output_size, stochastic_size,
                                              deterministic_size,
                                              hidden_size)
     self.rollout = RSSMRollout(self.representation, self.transition)
     self.action_size = output_size
     self.action_dist = action_dist
     self.action_decoder = ActionDecoder(output_size, feature_size,
                                         action_hidden_size, action_layers,
                                         action_dist)
     self.reward_model = DenseModel(feature_size, reward_shape,
                                    reward_layers, reward_hidden)
     self.value_model = DenseModel(feature_size, value_shape, value_layers,
                                   value_hidden)
     self.dtype = dtype
     self.stochastic_size = stochastic_size
     self.deterministic_size = deterministic_size
     if use_pcont:
         self.pcont = DenseModel(feature_size, (1, ),
                                 pcont_layers,
                                 pcont_hidden,
                                 dist='binary')
예제 #3
0
 def __init__(
     self,
     action_shape,
     stochastic_size=30,
     deterministic_size=200,
     hidden_size=200,
     image_shape=(3, 64, 64),
     stride=2,
     depth=32,
     padding=0,
     action_hidden_size=200,
     action_layers=3,
     action_dist='one_hot',
     reward_shape=(1, ),
     reward_layers=3,
     reward_hidden=300,
     value_shape=(1, ),
     value_layers=3,
     value_hidden=200,
     dtype=torch.float,
     state_size=None,
     use_pcont=False,
     pcont_layers=3,
     pcont_hidden=200,
     full_conv=True,
     **kwargs,
 ):
     """Construct the encoder/decoder, RSSM, action, reward and value modules.

     Args:
         action_shape: Shape of one action. The product of its entries is
             the flat action size handed to the transition, representation
             and action-decoder modules and stored as ``self.action_size``.
         stochastic_size: Stochastic state size given to the RSSM modules.
         deterministic_size: Deterministic state size given to the RSSM
             modules.
         hidden_size: Hidden size given to ``RSSMTransition`` and
             ``RSSMRepresentation``.
         image_shape: Passed as ``shape`` to the observation encoder and
             decoder.
         stride: Passed to both the observation encoder and decoder.
         depth: Passed to the observation encoder only.
         padding: Passed to the observation encoder only.
         action_hidden_size: Hidden size argument of ``ActionDecoder``.
         action_layers: Layer-count argument of ``ActionDecoder``.
         action_dist: Distribution name forwarded to ``ActionDecoder`` and
             stored as ``self.action_dist``.
         reward_shape: Output shape argument of the reward ``DenseModel``.
         reward_layers: Layer-count argument of the reward ``DenseModel``.
         reward_hidden: Hidden size argument of the reward ``DenseModel``.
         value_shape: Output shape argument of the value ``DenseModel``.
         value_layers: Layer-count argument of the value ``DenseModel``.
         value_hidden: Hidden size argument of the value ``DenseModel``.
         dtype: Stored as ``self.dtype``; not otherwise used here.
         state_size: Optional extra input size. When not ``None`` it is
             added to the encoder embed size given to
             ``RSSMRepresentation``, and ``self.use_state`` is set to true.
         use_pcont: When true, an extra ``DenseModel`` with
             ``dist='binary'`` is created as ``self.pcont``. When false the
             attribute is not created at all.
         pcont_layers: Layer-count argument of the ``pcont`` model.
         pcont_hidden: Hidden size argument of the ``pcont`` model.
         full_conv: Selects ``ObservationEncoder``/``ObservationDecoder``
             when true, otherwise the ``Mini*`` variants.
         **kwargs: Accepted and ignored.
     """
     super().__init__()
     if full_conv:
         EncoderClass = ObservationEncoder
         DecoderClass = ObservationDecoder
     else:
         EncoderClass = MiniObservationEncoder
         DecoderClass = MiniObservationDecoder
     self.observation_encoder = EncoderClass(shape=image_shape,
                                             stride=stride,
                                             depth=depth,
                                             padding=padding)
     encoder_embed_size = self.observation_encoder.embed_size
     self.use_state = state_size is not None
     self.state_size = state_size
     if self.use_state:
         # Widen the representation model's input by state_size.
         # NOTE(review): presumably a state vector is joined to the image
         # embedding elsewhere - confirm against the forward pass.
         encoder_embed_size += state_size
     # Sum of both state sizes: the decoder's embed size and the input size
     # of the action, reward, value and pcont heads.
     feature_size = stochastic_size + deterministic_size
     self.observation_decoder = DecoderClass(embed_size=feature_size,
                                             shape=image_shape,
                                             stride=stride)
     self.action_shape = action_shape
     # np.prod returns a NumPy scalar (float64 1.0 for an empty shape);
     # cast so layer sizes and self.action_size are always built-in ints.
     output_size = int(np.prod(action_shape))
     self.transition = RSSMTransition(output_size, stochastic_size,
                                      deterministic_size, hidden_size)
     self.representation = RSSMRepresentation(self.transition,
                                              encoder_embed_size,
                                              output_size, stochastic_size,
                                              deterministic_size,
                                              hidden_size)
     self.rollout = RSSMRollout(self.representation, self.transition)
     self.action_size = output_size
     self.action_dist = action_dist
     self.action_decoder = ActionDecoder(output_size, feature_size,
                                         action_hidden_size, action_layers,
                                         action_dist)
     self.reward_model = DenseModel(feature_size, reward_shape,
                                    reward_layers, reward_hidden)
     self.value_model = DenseModel(feature_size, value_shape, value_layers,
                                   value_hidden)
     self.dtype = dtype
     self.stochastic_size = stochastic_size
     self.deterministic_size = deterministic_size
     if use_pcont:
         self.pcont = DenseModel(feature_size, (1, ),
                                 pcont_layers,
                                 pcont_hidden,
                                 dist='binary')