def __init__(self, observation_space: spaces.Box, features_dim: int = 512):
    """
    CNN feature extractor for image observations (channels-first CxHxW).

    :param observation_space: Image observation space (must pass ``is_image_space``)
    :param features_dim: Size of the extracted feature vector
    """
    super().__init__(observation_space, features_dim)
    # We assume CxHxW images (channels first)
    # Re-ordering will be done by pre-preprocessing or wrapper
    assert is_image_space(observation_space), (
        "You should use NatureCNN "
        f"only with images not with {observation_space}\n"
        "(you are probably using `CnnPolicy` instead of `MlpPolicy`)\n"
        "If you are using a custom environment,\n"
        "please check it using our env checker:\n"
        "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
    )
    in_channels = observation_space.shape[0]
    # Three convolutional layers followed by flattening.
    conv_layers = [
        nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=0),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),
        nn.Flatten(),
    ]
    self.cnn = nn.Sequential(*conv_layers)

    # Infer the flattened size with a single dummy forward pass,
    # so the linear head works for any input resolution.
    with th.no_grad():
        sample = th.as_tensor(observation_space.sample()[None]).float()
        flat_dim = self.cnn(sample).shape[1]

    self.linear = nn.Sequential(nn.Linear(flat_dim, features_dim), nn.ReLU())
class FakeImageEnv(Env):
    """
    Fake image environment for testing purposes, it mimics Atari games.

    :param action_dim: Number of discrete actions
    :param screen_height: Height of the image
    :param screen_width: Width of the image
    :param n_channels: Number of color channels
    :param discrete: Create discrete action space instead of continuous
    :param channel_first: Put channels on first axis instead of last
    """

    def __init__(
        self,
        action_dim: int = 6,
        screen_height: int = 84,
        screen_width: int = 84,
        n_channels: int = 1,
        discrete: bool = True,
        channel_first: bool = False,
    ):
        # Channel axis position depends on the requested layout.
        if channel_first:
            self.observation_shape = (n_channels, screen_height, screen_width)
        else:
            self.observation_shape = (screen_height, screen_width, n_channels)
        self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
        # Either a discrete action space or a fixed 5-dim continuous one.
        if discrete:
            self.action_space = Discrete(action_dim)
        else:
            self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
        self.ep_length = 10
        self.current_step = 0

    def reset(self) -> np.ndarray:
        """Reset the step counter and return a random observation."""
        self.current_step = 0
        return self.observation_space.sample()

    def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
        """Take one step; the episode terminates after ``ep_length`` steps."""
        self.current_step += 1
        episode_over = self.current_step >= self.ep_length
        # Reward is always 0.0; the action is ignored by design.
        return self.observation_space.sample(), 0.0, episode_over, {}

    def render(self, mode: str = "human") -> None:
        """No-op rendering."""
        pass