def test_passive_environment_without_pretend_to_be_active():
    """ Test the gym.Env-style interaction with a PassiveEnvironment.
    """
    batch_size = 5
    transforms = Compose([Transforms.to_tensor, Transforms.three_channels])
    dataset = MNIST("data", transform=transforms)
    max_samples = 100
    dataset = Subset(dataset, list(range(max_samples)))

    obs_space = Image(0, 255, (1, 28, 28), np.uint8)
    obs_space = transforms(obs_space)
    env = PassiveEnvironment(
        dataset,
        n_classes=10,
        batch_size=batch_size,
        observation_space=obs_space,
        pretend_to_be_active=False,
    )
    assert env.observation_space == Image(0, 1, (batch_size, 3, 28, 28))
    assert env.action_space.shape == (batch_size, )
    assert env.reward_space == env.action_space
    env.seed(123)
    obs = env.reset()
    assert obs in env.observation_space

    obs, reward, done, info = env.step(env.action_space.sample())
    assert reward is not None

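    # When `pretend_to_be_active=False`, iterating yields (obs, reward) pairs directly;
    # send() still returns rewards, and they match the ones already yielded.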
    for i, (obs, reward) in enumerate(env):
        assert reward is not None
        other_reward = env.send(env.action_space.sample())
        assert (other_reward == reward).all()
    assert i == max_samples // batch_size - 1
Example #2
    def additional_transforms(self, stage_transforms: List[Transforms]) -> Compose:
        """ Returns the transforms in `stage_transforms` that are additional transforms
        from those in `self.transforms`.

        For example, if:
        ```
        setting.transforms = Compose([Transforms.Resize(32), Transforms.ToTensor])
        setting.train_transforms = Compose([Transforms.Resize(32), Transforms.ToTensor, Transforms.RandomGrayscale])
        ```
        Then:
        ```
        setting.additional_transforms(setting.train_transforms)
        # will give:
        Compose([Transforms.RandomGrayscale])
        ```
        """
        reference_transforms = self.transforms

        if len(stage_transforms) < len(reference_transforms):
            # Assume no overlap, return all the 'stage' transforms.
            return Compose(stage_transforms)
        if stage_transforms == reference_transforms:
            # Complete overlap, return an empty list.
            return Compose([])

        # Only keep the transforms that come after the 'base' transforms:
        # as soon as one entry differs, everything from there on is additional.
        i = 0
        for i, (t_a, t_b) in enumerate(zip(stage_transforms, reference_transforms)):
            if t_a != t_b:
                break
        else:
            # All reference transforms matched: the extras start right after them.
            i = len(reference_transforms)
        return Compose(stage_transforms[i:])
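A small usage sketch of the prefix-matching behaviour above (assuming a `setting`
instance that exposes `transforms` and this method; the enum members are the ones
used throughout the other examples here, so element-wise equality holds):

# Illustrative only, not part of the library's test suite.
setting.transforms = Compose([Transforms.to_tensor, Transforms.three_channels])
# Complete overlap: nothing additional.
assert setting.additional_transforms(
    [Transforms.to_tensor, Transforms.three_channels]
) == Compose([])
# One extra trailing transform: only the extra part is returned.
assert setting.additional_transforms(
    [Transforms.to_tensor, Transforms.three_channels, Transforms.resize_64x64]
) == Compose([Transforms.resize_64x64])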
Example #3
def test_observation_wrapper_applied_to_passive_environment():
    """ Test that when we apply a gym wrapper to a PassiveEnvironment, it also
    affects the observations / actions / rewards produced when iterating on the
    env.
    """
    batch_size = 5

    transforms = Compose([Transforms.to_tensor, Transforms.three_channels])
    dataset = MNIST("data", transform=transforms)
    obs_space = Image(0, 255, (1, 28, 28), np.uint8)
    obs_space = transforms(obs_space)
    env = PassiveEnvironment(
        dataset, n_classes=10, batch_size=batch_size, observation_space=obs_space,
    )

    assert env.observation_space == Image(0, 1, (batch_size, 3, 28, 28))
    assert env.action_space.shape == (batch_size,)
    assert env.reward_space == env.action_space

    env.seed(123)

    check_env(env)

    # Apply a transformation that changes the observation space.
    env = TransformObservation(env=env, f=Compose([Transforms.resize_64x64]))
    assert env.observation_space == Image(0, 1, (batch_size, 3, 64, 64))
    assert env.action_space.shape == (batch_size,)
    assert env.reward_space.shape == (batch_size,)

    env.seed(123)
    check_env(env)
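    # The docstring above also claims the wrapper affects the items produced when
    # iterating; a single-batch check (assuming iteration yields (obs, reward) pairs,
    # as in the other PassiveEnvironment examples).
    for obs, _ in env:
        assert obs.shape == (batch_size, 3, 64, 64)
        break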

    env.close()
Example #4
    def test_multiple_epochs_dataloader(self):
        """ Test that we can iterate on the dataloader more than once. """
        max_epochs = 3
        max_samples = 200
        batch_size = 5
        max_batches = max_samples // batch_size
        dataset = MNIST("data",
                        transform=Compose(
                            [Transforms.to_tensor, Transforms.three_channels]))
        dataset = Subset(dataset, list(range(max_samples)))

        env = self.PassiveEnvironment(dataset,
                                      n_classes=10,
                                      batch_size=batch_size)

        assert env.observation_space.shape == (batch_size, 3, 28, 28)
        assert env.action_space.shape == (batch_size, )
        assert env.reward_space.shape == (batch_size, )
        total_steps = 0
        for epoch in range(max_epochs):
            for obs, reward in env:
                assert obs.shape == (batch_size, 3, 28, 28)
                assert reward.shape == (batch_size, )
                total_steps += 1

        assert total_steps == max_batches * max_epochs
Example #5
def test_env_gives_done_on_last_item():
    # from continuum.datasets import MNIST
    max_samples = 100
    batch_size = 1
    dataset = MNIST("data",
                    transform=Compose(
                        [Transforms.to_tensor, Transforms.three_channels]))
    dataset = Subset(dataset, list(range(max_samples)))

    env = PassiveEnvironment(dataset, n_classes=10, batch_size=batch_size)

    assert env.observation_space.shape == (batch_size, 3, 28, 28)
    assert env.action_space.shape == (batch_size, )
    assert env.reward_space.shape == (batch_size, )

    env.seed(123)
    obs = env.reset()
    assert obs.shape == (batch_size, 3, 28, 28)
    # Starting at 1 since reset() gives one observation already.
    for i in range(1, max_samples):
        obs, reward, done, info = env.step(env.action_space.sample())
        assert obs.shape == (batch_size, 3, 28, 28)
        assert reward.shape == (batch_size, )
        assert done == (i == max_samples - 1), i
        if done:
            break
    else:
        assert False, "Should have reached done=True!"
    assert i == max_samples - 1
    env.close()
Example #6
def test_is_proxy_to(use_wrapper: bool):
    import numpy as np
    from sequoia.common.transforms import Compose, Transforms

    transforms = Compose([Transforms.to_tensor, Transforms.three_channels])
    from sequoia.common.spaces import Image
    from torchvision.datasets import MNIST

    batch_size = 32
    dataset = MNIST("data", transform=transforms)
    obs_space = Image(0, 255, (1, 28, 28), np.uint8)
    obs_space = transforms(obs_space)

    env_type = ProxyPassiveEnvironment if use_wrapper else PassiveEnvironment
    env: Iterable[Tuple[Tensor, Tensor]] = env_type(
        dataset,
        batch_size=batch_size,
        n_classes=10,
        observation_space=obs_space,
    )
    if use_wrapper:
        assert isinstance(env, EnvironmentProxy)
        assert issubclass(type(env), EnvironmentProxy)
        assert is_proxy_to(env, PassiveEnvironment)
    else:
        assert not is_proxy_to(env, PassiveEnvironment)
Example #7
    def test_mnist_as_gym_env(self):
        # from continuum.datasets import MNIST
        dataset = MNIST("data",
                        transform=Compose(
                            [Transforms.to_tensor, Transforms.three_channels]))

        batch_size = 4
        env = self.PassiveEnvironment(dataset,
                                      n_classes=10,
                                      batch_size=batch_size)

        assert env.observation_space.shape == (batch_size, 3, 28, 28)
        assert env.action_space.shape == (batch_size, )
        assert env.reward_space.shape == (batch_size, )

        env.seed(123)
        obs = env.reset()
        assert obs.shape == (batch_size, 3, 28, 28)

        for i in range(10):
            obs, reward, done, info = env.step(env.action_space.sample())
            assert obs.shape == (batch_size, 3, 28, 28)
            assert reward.shape == (batch_size, )
            assert not done
        env.close()
Example #8
def test_multiple_epochs_dataloader_with_split_batch_fn():
    """ Test that we can iterate on the dataloader more than once. """
    max_epochs = 3
    max_samples = 200
    batch_size = 5

    def split_batch_fn(batch):
        x, y = batch
        # some dummy function.
        return torch.zeros_like(x), y

    max_batches = max_samples // batch_size
    dataset = MNIST("data",
                    transform=Compose(
                        [Transforms.to_tensor, Transforms.three_channels]))
    dataset = Subset(dataset, list(range(max_samples)))

    env = PassiveEnvironment(dataset,
                             n_classes=10,
                             batch_size=batch_size,
                             split_batch_fn=split_batch_fn)

    assert env.observation_space.shape == (batch_size, 3, 28, 28)
    assert env.action_space.shape == (batch_size, )
    assert env.reward_space.shape == (batch_size, )
    total_steps = 0
    for epoch in range(max_epochs):
        for obs, reward in env:
            assert obs.shape == (batch_size, 3, 28, 28)
            assert torch.all(obs == 0)
            assert reward.shape == (batch_size, )
            total_steps += 1

    assert total_steps == max_batches * max_epochs
Example #9
    def __init__(self, env: gym.Env, f: Union[Callable, Compose]):
        if isinstance(f, list) and not callable(f):
            f = Compose(f)
        super().__init__(env, f=f)
        self.f: Transform
        self.observation_space = self(self.env.observation_space)
        if has_tensor_support(self.env.observation_space):
            self.observation_space = add_tensor_support(self.observation_space)
Example #10
    def __init__(self, env: gym.Env, f: Callable[[Union[gym.Env, Space]], Union[gym.Env, Space]]):
        if isinstance(f, list) and not callable(f):
            f = Compose(f)
        super().__init__(env)
        self.f: Compose = f
        # Modify the action space by applying the transform onto it.
        self.action_space = self.env.action_space

        if isinstance(self.f, Transform):
            self.action_space = self.f(self.env.action_space)
Example #11
def test_compose_on_image_space():
    in_space = Image(0, 255, shape=(64, 64, 3), dtype=np.uint8)
    transform = Compose([Transforms.to_tensor, Transforms.three_channels])
    expected = Image(0, 1., shape=(3, 64, 64), dtype=np.float32)
    actual = transform(in_space)

    assert actual == expected
    env = gym.make("MetaMonsterKong-v0")
    assert env.observation_space == gym.spaces.Box(0, 255, (64, 64, 3), np.uint8)
    assert env.observation_space == in_space
    wrapped_env = TransformObservation(env, transform)
    assert wrapped_env.observation_space == expected
Example #12
def test_passive_environment_needs_actions_to_be_sent():
    """ Test the 'active dataloader' style interaction.
    """
    batch_size = 10
    transforms = Compose([Transforms.to_tensor, Transforms.three_channels])
    dataset = MNIST("data", transform=transforms)
    max_samples = 105
    dataset = Subset(dataset, list(range(max_samples)))

    obs_space = Image(0, 255, (1, 28, 28), np.uint8)
    obs_space = transforms(obs_space)
    env = PassiveEnvironment(
        dataset,
        n_classes=10,
        batch_size=batch_size,
        observation_space=obs_space,
        pretend_to_be_active=True,
        strict=True,
    )

    with pytest.raises(RuntimeError):
        for i, (obs, _) in enumerate(env):
            pass

    env = PassiveEnvironment(
        dataset,
        n_classes=10,
        batch_size=batch_size,
        observation_space=obs_space,
        pretend_to_be_active=True,
    )
    for i, (obs, _) in enumerate(env):
        assert isinstance(obs, Tensor)
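        # The last batch is smaller than batch_size here (105 samples, batches of 10),
        # so only keep as many sampled actions as there are observations.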
        action = env.action_space.sample()[:obs.shape[0]]
        rewards = env.send(action)
        assert rewards is not None
        assert rewards.shape[0] == action.shape[0]
Example #13
def test_env_requires_reset_before_step():
    # from continuum.datasets import MNIST
    max_samples = 100
    batch_size = 5
    max_batches = max_samples // batch_size
    dataset = MNIST("data",
                    transform=Compose(
                        [Transforms.to_tensor, Transforms.three_channels]))
    dataset = Subset(dataset, list(range(max_samples)))

    env = PassiveEnvironment(dataset, n_classes=10, batch_size=batch_size)

    with pytest.raises(gym.error.ResetNeeded):
        env.step(env.action_space.sample())
Example #14
    def __post_init__(self,
                      observation_space: gym.Space = None,
                      action_space: gym.Space = None,
                      reward_space: gym.Space = None):
        """ Initializes the fields of the setting that weren't set from the
        command-line.
        """
        logger.debug(f"__post_init__ of Setting")
        if len(self.train_transforms) == 1 and isinstance(self.train_transforms[0], list):
            self.train_transforms = self.train_transforms[0]
        if len(self.val_transforms) == 1 and isinstance(self.val_transforms[0], list):
            self.val_transforms = self.val_transforms[0]
        if len(self.test_transforms) == 1 and isinstance(self.test_transforms[0], list):
            self.test_transforms = self.test_transforms[0]

        # Actually compose the list of Transforms or callables into a single transform.
        self.train_transforms: Compose = Compose(self.train_transforms)
        self.val_transforms: Compose = Compose(self.val_transforms)
        self.test_transforms: Compose = Compose(self.test_transforms)

        LightningDataModule.__init__(self,
            train_transforms=self.train_transforms,
            val_transforms=self.val_transforms,
            test_transforms=self.test_transforms,
        )
        
        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_space = reward_space

        # TODO: It's a bit confusing to also have a `config` attribute on the
        # Setting. Might want to change this a bit.
        self.config: Config = None

        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: Environment = None  # type: ignore
Example #15
def test_passive_environment_as_dataloader():
    batch_size = 1
    transforms = Compose([Transforms.to_tensor, Transforms.three_channels])
    dataset = MNIST("data", transform=transforms)
    obs_space = Image(0, 255, (1, 28, 28), np.uint8)
    obs_space = transforms(obs_space)

    env: Iterable[Tuple[Tensor, Tensor]] = PassiveEnvironment(
        dataset, batch_size=batch_size, n_classes=10, observation_space=obs_space,
    )

    for x, y in env:
        assert x.shape == (batch_size, 3, 28, 28)
        x = x.permute(0, 2, 3, 1)
        assert y.tolist() == [5]
        break
Example #16
    def test_multiple_epochs_env(self):
        max_epochs = 3
        max_samples = 100
        batch_size = 5
        max_batches = max_samples // batch_size
        dataset = MNIST("data",
                        transform=Compose(
                            [Transforms.to_tensor, Transforms.three_channels]))
        dataset = Subset(dataset, list(range(max_samples)))

        env = self.PassiveEnvironment(dataset,
                                      n_classes=10,
                                      batch_size=batch_size)

        assert env.observation_space.shape == (batch_size, 3, 28, 28)
        assert env.action_space.shape == (batch_size, )
        assert env.reward_space.shape == (batch_size, )

        env.seed(123)
        total_steps = 0
        for epoch in range(max_epochs):
            obs = env.reset()
            total_steps += 1

            assert obs.shape == (batch_size, 3, 28, 28)
            # Starting at 1 since reset() gives one observation already.
            for i in range(1, max_batches):
                obs, reward, done, info = env.step(env.action_space.sample())
                assert obs.shape == (batch_size, 3, 28, 28)
                assert reward.shape == (batch_size, )
                assert done == (i == max_batches - 1), i
                total_steps += 1
                if done:
                    break
            else:
                assert False, "Should have reached done=True!"
            assert i == max_batches - 1
        assert total_steps == max_batches * max_epochs

        env.close()
Example #17
    def test_cant_iterate_after_closing_passive_env(self):
        max_epochs = 3
        max_samples = 200
        batch_size = 5
        max_batches = max_samples // batch_size
        dataset = MNIST("data",
                        transform=Compose(
                            [Transforms.to_tensor, Transforms.three_channels]))
        dataset = Subset(dataset, list(range(max_samples)))

        env = self.PassiveEnvironment(dataset,
                                      n_classes=10,
                                      batch_size=batch_size,
                                      num_workers=4)

        assert env.observation_space.shape == (batch_size, 3, 28, 28)
        assert env.action_space.shape == (batch_size, )
        assert env.reward_space.shape == (batch_size, )
        total_steps = 0
        for epoch in range(max_epochs):
            for obs, reward in env:
                assert obs.shape == (batch_size, 3, 28, 28)
                assert reward.shape == (batch_size, )
                total_steps += 1
        assert total_steps == max_batches * max_epochs

        env.close()

        with pytest.raises(gym.error.ClosedEnvironmentError):
            for _ in zip(range(3), env):
                pass

        with pytest.raises(gym.error.ClosedEnvironmentError):
            env.reset()

        with pytest.raises(gym.error.ClosedEnvironmentError):
            env.get_next_batch()

        with pytest.raises(gym.error.ClosedEnvironmentError):
            env.step(env.action_space.sample())
Example #18
    def __init__(self, env: gym.Env, f: Union[Callable, Compose]):
        if isinstance(f, list) and not callable(f):
            f = Compose(f)
        super().__init__(env, f=f)
        self.f: Compose
        # Modify the reward space, if it exists.
        if hasattr(self.env, "reward_space"):
            self.reward_space = self.env.reward_space
        else:
            self.reward_space = spaces.Box(
                low=self.env.reward_range[0],
                high=self.env.reward_range[1],
                shape=(),
            )

        try:
            self.reward_space = self.f(self.reward_space)
            logger.debug(f"New reward space after transform: {self.reward_space}")
        except Exception as e:
            logger.warning(UserWarning(
                f"Don't know how the transform {self.f} will impact the "
                f"observation space! (Exception: {e})"
            ))
Example #19
def test_issue_204():
    """ Test that reproduces the issue #204, which was that some zombie processes
    appeared to be created when iterating using an EnvironmentProxy.
    
    The issue appears to have been caused by calling `self.__environment.reset()` in
    `__iter__`, which I think caused another dataloader iterator to be created?
    """
    transforms = Compose([Transforms.to_tensor, Transforms.three_channels])

    batch_size = 2048
    num_workers = 12

    dataset = MNIST("data", transform=transforms)
    obs_space = Image(0, 255, (1, 28, 28), np.uint8)
    obs_space = transforms(obs_space)

    current_process = psutil.Process()
    print(
        f"Current process is using {current_process.num_threads()} threads, with "
        f" {len(current_process.children(recursive=True))} child processes.")
    starting_threads = current_process.num_threads()
    starting_processes = len(current_process.children(recursive=True))

    for use_wrapper in [False, True]:

        threads = current_process.num_threads()
        processes = len(current_process.children(recursive=True))
        assert threads == starting_threads
        assert processes == starting_processes

        env_type = ProxyPassiveEnvironment if use_wrapper else PassiveEnvironment
        env: Iterable[Tuple[Tensor, Tensor]] = env_type(
            dataset,
            batch_size=batch_size,
            n_classes=10,
            observation_space=obs_space,
            num_workers=num_workers,
            persistent_workers=True,
        )
        for i, _ in enumerate(env):
            threads = current_process.num_threads()
            processes = len(current_process.children(recursive=True))
            assert threads == starting_threads + num_workers
            assert processes == starting_processes + num_workers
            print(f"Current process is using {threads} threads, with "
                  f" {processes} child processes.")

        for i, _ in enumerate(env):
            threads = current_process.num_threads()
            processes = len(current_process.children(recursive=True))
            assert threads == starting_threads + num_workers
            assert processes == starting_processes + num_workers
            print(f"Current process is using {threads} threads, with "
                  f" {processes} child processes.")

        obs = env.reset()
        done = False
        while not done:
            obs, reward, done, info = env.step(env.action_space.sample())

            # env.render(mode="human")

            threads = current_process.num_threads()
            processes = len(current_process.children(recursive=True))
            if not done:
                assert threads == starting_threads + num_workers
                assert processes == starting_processes + num_workers
                print(f"Current process is using {threads} threads, with "
                      f" {processes} child processes.")

        env.close()

        import time
        # Need to give it a second (or so) to cleanup.
        time.sleep(1)

        threads = current_process.num_threads()
        processes = len(current_process.children(recursive=True))
        assert threads == starting_threads
        assert processes == starting_processes
Example #20
    def __post_init__(
        self,
        observation_space: gym.Space = None,
        action_space: gym.Space = None,
        reward_space: gym.Space = None,
    ):
        """ Initializes the fields of the setting that weren't set from the
        command-line.
        """
        logger.debug("__post_init__ of Setting")
        # BUG: simple-parsing sometimes parses a list with a single item, itself the
        # list of transforms. Not sure if this still happens.

        def is_list_of_list(v: Any) -> bool:
            return isinstance(v, list) and len(v) == 1 and isinstance(v[0], list)

        if is_list_of_list(self.train_transforms):
            self.train_transforms = self.train_transforms[0]
        if is_list_of_list(self.val_transforms):
            self.val_transforms = self.val_transforms[0]
        if is_list_of_list(self.test_transforms):
            self.test_transforms = self.test_transforms[0]

        if all(
            t is None
            for t in [
                self.transforms,
                self.train_transforms,
                self.val_transforms,
                self.test_transforms,
            ]
        ):
            # Use these two transforms by default if no transforms are passed at all.
            # TODO: Remove this after the competition perhaps.
            self.transforms = Compose([Transforms.to_tensor, Transforms.three_channels])

        # If the constructor is called with just the `transforms` argument, like this:
        # <SomeSetting>(dataset="bob", transforms=foo_transform)
        # Then we use this value as the default for the train, val and test transforms.
        if self.transforms and not any(
            [self.train_transforms, self.val_transforms, self.test_transforms]
        ):
            if not isinstance(self.transforms, list):
                self.transforms = Compose([self.transforms])
            self.train_transforms = self.transforms.copy()
            self.val_transforms = self.transforms.copy()
            self.test_transforms = self.transforms.copy()

        if self.train_transforms is not None and not isinstance(
            self.train_transforms, list
        ):
            self.train_transforms = [self.train_transforms]

        if self.val_transforms is not None and not isinstance(
            self.val_transforms, list
        ):
            self.val_transforms = [self.val_transforms]

        if self.test_transforms is not None and not isinstance(
            self.test_transforms, list
        ):
            self.test_transforms = [self.test_transforms]

        # Actually compose the list of Transforms or callables into a single transform.
        self.train_transforms: Compose = Compose(self.train_transforms or [])
        self.val_transforms: Compose = Compose(self.val_transforms or [])
        self.test_transforms: Compose = Compose(self.test_transforms or [])

        LightningDataModule.__init__(
            self,
            train_transforms=self.train_transforms,
            val_transforms=self.val_transforms,
            test_transforms=self.test_transforms,
        )

        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_space = reward_space

        # TODO: It's a bit confusing to also have a `config` attribute on the
        # Setting. Might want to change this a bit.
        self.config: Config = None

        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: Environment = None  # type: ignore
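A small sketch of the transform-propagation behaviour described in the comments above
(`SomeSetting` is a hypothetical concrete Setting subclass, as in the comment; the
equality semantics of `Compose` are assumed to be list-like):

# Illustrative only: when just `transforms` is passed, it becomes the default for the
# train/val/test transforms, each wrapped into a Compose by __post_init__.
setting = SomeSetting(dataset="bob", transforms=[Transforms.to_tensor])
assert setting.train_transforms == Compose([Transforms.to_tensor])
assert setting.val_transforms == Compose([Transforms.to_tensor])
assert setting.test_transforms == Compose([Transforms.to_tensor])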
Example #21
def test_split_batch_fn():
    # from continuum.datasets import MNIST
    batch_size = 5
    max_batches = 10

    def split_batch_fn(
        batch: Tuple[Tensor, Tensor, Tensor]
    ) -> Tuple[Tuple[Tensor, Tensor], Tensor]:
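        # Move the task label `t` into the observation and use the class label `y` as
        # the reward; this is why an explicit observation_space is passed below.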
        x, y, t = batch
        return (x, t), y

    # dataset = MNIST("data", transform=Compose([Transforms.to_tensor, Transforms.three_channels]))
    from continuum import ClassIncremental
    from continuum.datasets import MNIST
    from continuum.tasks import split_train_val

    scenario = ClassIncremental(
        MNIST("data", download=True, train=True),
        increment=2,
        transformations=Compose(
            [Transforms.to_tensor, Transforms.three_channels]),
    )

    classes_per_task = scenario.nb_classes // scenario.nb_tasks
    print(f"Number of classes per task {classes_per_task}.")
    for i, task_dataset in enumerate(scenario):
        env = PassiveEnvironment(
            task_dataset,
            n_classes=classes_per_task,
            batch_size=batch_size,
            split_batch_fn=split_batch_fn,
            # Need to pass the observation space, in this case.
            observation_space=spaces.Tuple([
                spaces.Box(low=0, high=1, shape=(3, 28, 28)),
                spaces.Discrete(scenario.nb_tasks),  # task label
            ]),
            action_space=spaces.Box(
                low=np.array([i * classes_per_task]),
                high=np.array([(i + 1) * classes_per_task]),
                dtype=int,
            ),
        )
        assert spaces.Box(
            low=np.array([i * classes_per_task]),
            high=np.array([(i + 1) * classes_per_task]),
            dtype=int,
        ).shape == (1, )
        assert isinstance(env.observation_space[0], spaces.Box)
        assert env.observation_space[0].shape == (batch_size, 3, 28, 28)
        assert env.observation_space[1].shape == (batch_size, )
        assert env.action_space.shape == (batch_size, 1)
        assert env.reward_space.shape == (batch_size, 1)

        env.seed(123)

        obs = env.reset()
        assert len(obs) == 2
        x, t = obs
        assert x.shape == (batch_size, 3, 28, 28)
        assert t.shape == (batch_size, )

        obs, reward, done, info = env.step(env.action_space.sample())
        x, t = obs
        assert x.shape == (batch_size, 3, 28, 28)
        assert t.shape == (batch_size, )
        assert reward.shape == (batch_size, )
        assert not done

        env.close()