def test_rollouts():
    """Check result types and shapes of the three ``RSSMRollout`` entry points.

    Exercises, with random inputs:

    * calling the rollout module directly (returns a prior and a posterior),
    * ``rollout_transition`` (returns a prior for the given actions),
    * ``rollout_policy`` (returns a prior and the actions for a policy
      callable).

    Every returned state must be an ``RSSMState`` whose ``mean``, ``std`` and
    ``stoch`` fields have shape ``(time_steps, batch_size, stochastic_size)``
    and whose ``deter`` field has shape
    ``(time_steps, batch_size, deterministic_size)``.
    """
    action_size = 10
    obs_embed_size = 100
    stochastic_size = 30
    deterministic_size = 200
    batch_size = 4
    time_steps = 10

    def assert_state_shapes(state):
        """Assert ``state`` is an ``RSSMState`` with the expected field shapes."""
        assert isinstance(state, RSSMState)
        stoch_shape = (time_steps, batch_size, stochastic_size)
        assert state.mean.shape == stoch_shape
        assert state.std.shape == stoch_shape
        assert state.stoch.shape == stoch_shape
        assert state.deter.shape == (time_steps, batch_size, deterministic_size)

    transition_model = RSSMTransition(action_size, stochastic_size, deterministic_size)
    representation_model = RSSMRepresentation(
        transition_model, obs_embed_size, action_size, stochastic_size, deterministic_size)
    rollout_module = RSSMRollout(representation_model, transition_model)

    obs_embed: torch.Tensor = torch.randn(time_steps, batch_size, obs_embed_size)
    action: torch.Tensor = torch.randn(time_steps, batch_size, action_size)
    prev_state: RSSMState = representation_model.initial_state(batch_size)

    # Module call: prior and posterior from observation embeddings and actions.
    prior, post = rollout_module(time_steps, obs_embed, action, prev_state)
    assert_state_shapes(prior)
    assert_state_shapes(post)

    # Transition-only rollout from the transition model's initial state.
    prior = rollout_module.rollout_transition(
        time_steps, action, transition_model.initial_state(batch_size))
    assert_state_shapes(prior)

    def policy(state):
        """Return a random action and a random action distribution."""
        rows = state.stoch.size(0)
        action = torch.randn(rows, action_size)
        mean = torch.randn(rows, action_size)
        # Normal's scale must be strictly positive; a randn draw is negative
        # about half the time and is rejected by argument validation, so
        # sample from [0.1, 1.1) instead.
        std = torch.rand(rows, action_size) + 0.1
        action_dist = SampleDist(torch.distributions.Normal(mean, std))
        return action, action_dist

    # Policy rollout starting from the last posterior step.
    prior, actions = rollout_module.rollout_policy(time_steps, policy, post[-1])
    assert_state_shapes(prior)
    assert isinstance(actions, torch.Tensor)
    assert actions.shape == (time_steps, batch_size, action_size)
def __init__(
    self,
    action_shape,
    stochastic_size=30,
    deterministic_size=200,
    hidden_size=200,
    image_shape=(3, 64, 64),
    action_hidden_size=200,
    action_layers=3,
    action_dist='one_hot',
    reward_shape=(1, ),
    reward_layers=3,
    reward_hidden=300,
    value_shape=(1, ),
    value_layers=3,
    value_hidden=200,
    dtype=torch.float,
    use_pcont=False,
    pcont_layers=3,
    pcont_hidden=200,
    **kwargs,
):
    """Assemble the observation encoder/decoder, RSSM and prediction models.

    Args:
        action_shape: Shape of an action. Its product (``np.prod``) is the
            action size handed to the RSSM models and the action decoder.
        stochastic_size: Size passed to the RSSM models; also stored on
            ``self``.
        deterministic_size: Size passed to the RSSM models; also stored on
            ``self``.
        hidden_size: Passed to ``RSSMTransition`` and ``RSSMRepresentation``.
        image_shape: Passed as ``shape`` to the observation encoder and
            decoder.
        action_hidden_size: Passed to ``ActionDecoder``.
        action_layers: Passed to ``ActionDecoder``.
        action_dist: Passed to ``ActionDecoder`` and stored on ``self``.
        reward_shape: Passed to the reward ``DenseModel``.
        reward_layers: Passed to the reward ``DenseModel``.
        reward_hidden: Passed to the reward ``DenseModel``.
        value_shape: Passed to the value ``DenseModel``.
        value_layers: Passed to the value ``DenseModel``.
        value_hidden: Passed to the value ``DenseModel``.
        dtype: Stored as ``self.dtype``; not otherwise used here.
        use_pcont: When true, an extra ``DenseModel`` with ``dist='binary'``
            is created as ``self.pcont``. When false, no ``pcont`` attribute
            is set.
        pcont_layers: Passed to the ``pcont`` model.
        pcont_hidden: Passed to the ``pcont`` model.
        **kwargs: Accepted and ignored.
    """
    super().__init__()

    # Sum of the stochastic and deterministic sizes: used both as the decoder
    # embed size and as the input size of the action/reward/value models.
    latent_size = stochastic_size + deterministic_size
    flat_action_size = np.prod(action_shape)

    # Plain configuration attributes.
    self.action_shape = action_shape
    self.action_size = flat_action_size
    self.action_dist = action_dist
    self.dtype = dtype
    self.stochastic_size = stochastic_size
    self.deterministic_size = deterministic_size

    # Submodules are created in the same order as before so that parameter
    # registration order and initialisation RNG use are unchanged.
    self.observation_encoder = ObservationEncoder(shape=image_shape)
    embed_dim = self.observation_encoder.embed_size
    self.observation_decoder = ObservationDecoder(
        embed_size=latent_size, shape=image_shape)

    self.transition = RSSMTransition(
        flat_action_size, stochastic_size, deterministic_size, hidden_size)
    self.representation = RSSMRepresentation(
        self.transition, embed_dim, flat_action_size,
        stochastic_size, deterministic_size, hidden_size)
    self.rollout = RSSMRollout(self.representation, self.transition)

    self.action_decoder = ActionDecoder(
        flat_action_size, latent_size, action_hidden_size,
        action_layers, action_dist)
    self.reward_model = DenseModel(
        latent_size, reward_shape, reward_layers, reward_hidden)
    self.value_model = DenseModel(
        latent_size, value_shape, value_layers, value_hidden)

    if use_pcont:
        self.pcont = DenseModel(
            latent_size, (1, ), pcont_layers, pcont_hidden, dist='binary')
def __init__(
    self,
    action_shape,
    stochastic_size=30,
    deterministic_size=200,
    hidden_size=200,
    image_shape=(3, 64, 64),
    stride=2,
    depth=32,
    padding=0,
    action_hidden_size=200,
    action_layers=3,
    action_dist='one_hot',
    reward_shape=(1, ),
    reward_layers=3,
    reward_hidden=300,
    value_shape=(1, ),
    value_layers=3,
    value_hidden=200,
    dtype=torch.float,
    state_size=None,
    use_pcont=False,
    pcont_layers=3,
    pcont_hidden=200,
    full_conv=True,
    **kwargs,
):
    """Assemble the observation encoder/decoder, RSSM and prediction models.

    Args:
        action_shape: Shape of an action. Its product (``np.prod``) is the
            action size handed to the RSSM models and the action decoder.
        stochastic_size: Size passed to the RSSM models; also stored on
            ``self``.
        deterministic_size: Size passed to the RSSM models; also stored on
            ``self``.
        hidden_size: Passed to ``RSSMTransition`` and ``RSSMRepresentation``.
        image_shape: Passed as ``shape`` to the observation encoder and
            decoder.
        stride: Passed to both the observation encoder and decoder.
        depth: Passed to the observation encoder only.
        padding: Passed to the observation encoder only.
        action_hidden_size: Passed to ``ActionDecoder``.
        action_layers: Passed to ``ActionDecoder``.
        action_dist: Passed to ``ActionDecoder`` and stored on ``self``.
        reward_shape: Passed to the reward ``DenseModel``.
        reward_layers: Passed to the reward ``DenseModel``.
        reward_hidden: Passed to the reward ``DenseModel``.
        value_shape: Passed to the value ``DenseModel``.
        value_layers: Passed to the value ``DenseModel``.
        value_hidden: Passed to the value ``DenseModel``.
        dtype: Stored as ``self.dtype``; not otherwise used here.
        state_size: When not ``None``, it is added to the encoder's
            ``embed_size`` to form the embed size given to
            ``RSSMRepresentation``. Stored as ``self.state_size``;
            ``self.use_state`` records whether it was given.
        use_pcont: When true, an extra ``DenseModel`` with ``dist='binary'``
            is created as ``self.pcont``. When false, no ``pcont`` attribute
            is set.
        pcont_layers: Passed to the ``pcont`` model.
        pcont_hidden: Passed to the ``pcont`` model.
        full_conv: Selects ``ObservationEncoder``/``ObservationDecoder`` when
            true, otherwise ``MiniObservationEncoder``/
            ``MiniObservationDecoder``.
        **kwargs: Accepted and ignored.
    """
    super().__init__()

    # Sum of the stochastic and deterministic sizes: used both as the decoder
    # embed size and as the input size of the action/reward/value models.
    latent_size = stochastic_size + deterministic_size
    flat_action_size = np.prod(action_shape)
    has_state = state_size is not None

    # Plain configuration attributes.
    self.action_shape = action_shape
    self.action_size = flat_action_size
    self.action_dist = action_dist
    self.dtype = dtype
    self.stochastic_size = stochastic_size
    self.deterministic_size = deterministic_size
    self.use_state = has_state
    self.state_size = state_size

    encoder_cls, decoder_cls = (
        (ObservationEncoder, ObservationDecoder)
        if full_conv
        else (MiniObservationEncoder, MiniObservationDecoder)
    )

    # Submodules are created in the same order as before so that parameter
    # registration order and initialisation RNG use are unchanged.
    self.observation_encoder = encoder_cls(
        shape=image_shape, stride=stride, depth=depth, padding=padding)
    embed_dim = self.observation_encoder.embed_size
    if has_state:
        # The representation model's embed size is widened by state_size.
        embed_dim += state_size
    self.observation_decoder = decoder_cls(
        embed_size=latent_size, shape=image_shape, stride=stride)

    self.transition = RSSMTransition(
        flat_action_size, stochastic_size, deterministic_size, hidden_size)
    self.representation = RSSMRepresentation(
        self.transition, embed_dim, flat_action_size,
        stochastic_size, deterministic_size, hidden_size)
    self.rollout = RSSMRollout(self.representation, self.transition)

    self.action_decoder = ActionDecoder(
        flat_action_size, latent_size, action_hidden_size,
        action_layers, action_dist)
    self.reward_model = DenseModel(
        latent_size, reward_shape, reward_layers, reward_hidden)
    self.value_model = DenseModel(
        latent_size, value_shape, value_layers, value_hidden)

    if use_pcont:
        self.pcont = DenseModel(
            latent_size, (1, ), pcont_layers, pcont_hidden, dist='binary')