Example No. 1
    def test_dqfd(self):
        # Create a fake environment to test with.
        # TODO(b/152596848): Allow DQN to deal with integer observations.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Build demonstrations.
        dummy_action = np.zeros((), dtype=np.int32)
        recorder = bsuite_demonstrations.DemonstrationRecorder()
        timestep = environment.reset()
        while timestep.step_type is not dm_env.StepType.LAST:
            recorder.step(timestep, dummy_action)
            timestep = environment.step(dummy_action)
        recorder.step(timestep, dummy_action)
        recorder.record_episode()

        # Construct the agent.
        agent = dqfd.DQfD(environment_spec=spec,
                          network=_make_network(spec.actions),
                          demonstration_dataset=recorder.make_tf_dataset(),
                          demonstration_ratio=0.5,
                          batch_size=10,
                          samples_per_insert=2,
                          min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=10)
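The `_make_network` helper used above (and again in Examples No. 15 and 19 below) is not shown in these snippets. A minimal sketch of what a DQN-style helper could look like, assuming a Sonnet MLP Q-network similar to the ones built inline elsewhere on this page; this is an illustration, not necessarily the original helper:

    import sonnet as snt
    from acme import specs

    def _make_network(action_spec: specs.DiscreteArray) -> snt.Module:
        # Flatten observations and emit one Q-value per discrete action.
        return snt.Sequential([
            snt.Flatten(),
            snt.nets.MLP([50, 50, action_spec.num_values]),
        ])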
Example No. 2
    def test_mcts(self):
        # Create a fake environment to test with.
        num_actions = 5
        environment = fakes.DiscreteEnvironment(num_actions=num_actions,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        network = snt.Sequential([
            snt.Flatten(),
            snt.nets.MLP([50, 50]),
            networks.PolicyValueHead(spec.actions.num_values),
        ])
        model = simulator.Simulator(environment)
        optimizer = snt.optimizers.Adam(1e-3)

        # Construct the agent.
        agent = mcts.MCTS(environment_spec=spec,
                          network=network,
                          model=model,
                          optimizer=optimizer,
                          n_step=1,
                          discount=1.,
                          replay_capacity=100,
                          num_simulations=10,
                          batch_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example No. 3
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Make network purely functional
        network_hk = hk.without_apply_rng(hk.transform(network,
                                                       apply_rng=True))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

        network = networks_lib.FeedForwardNetwork(
            init=lambda rng: network_hk.init(rng, dummy_obs),
            apply=network_hk.apply)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example No. 4
    def test_reset(self):
        """Ensure that noop starts `reset` steps the environment multiple times."""
        noop_action = 0
        noop_max = 10
        seed = 24

        base_env = fakes.DiscreteEnvironment(action_dtype=np.int64,
                                             obs_dtype=np.int64,
                                             reward_spec=specs.Array(
                                                 dtype=np.float64, shape=()))
        mock_step_fn = mock.MagicMock()
        expected_num_step_calls = np.random.RandomState(seed).randint(
            noop_max + 1)

        with mock.patch.object(base_env, 'step', mock_step_fn):
            env = wrappers.NoopStartsWrapper(
                base_env,
                noop_action=noop_action,
                noop_max=noop_max,
                seed=seed,
            )
            env.reset()

            # Test environment step called with noop action as part of wrapper.reset
            mock_step_fn.assert_called_with(noop_action)
            self.assertEqual(mock_step_fn.call_count, expected_num_step_calls)
            self.assertEqual(mock_step_fn.call_args, ((noop_action, ), {}))
Example No. 5
    def run_ppo_agent(self, make_networks_fn):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        distribution_value_networks = make_networks_fn(spec)
        ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
        config = ppo.PPOConfig(unroll_length=4,
                               num_epochs=2,
                               num_minibatches=2)
        workdir = self.create_tempdir()
        counter = counting.Counter()
        logger = loggers.make_default_logger('learner')
        # Construct the agent.
        agent = ppo.PPO(
            spec=spec,
            networks=ppo_networks,
            config=config,
            seed=0,
            workdir=workdir.full_path,
            normalize_input=True,
            counter=counter,
            logger=logger,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
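`run_ppo_agent` receives a `make_networks_fn` that builds the distribution/value networks passed to `ppo.make_ppo_networks`. A hedged sketch of one possible implementation for the discrete fake environment above; the layer sizes and the use of `CategoricalValueHead` are assumptions, not the original test helper:

    import haiku as hk
    from acme.jax import networks as networks_lib
    from acme.jax import utils

    def make_discrete_networks(spec):
        num_actions = spec.actions.num_values

        def forward_fn(obs):
            # Shared torso feeding a categorical policy head and a scalar value head.
            torso = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64])])
            return networks_lib.CategoricalValueHead(num_values=num_actions)(torso(obs))

        forward = hk.without_apply_rng(hk.transform(forward_fn))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
        return networks_lib.FeedForwardNetwork(
            init=lambda key: forward.init(key, dummy_obs),
            apply=forward.apply)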
Example No. 6
    def test_impala(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x, s):
            model = MyNetwork(spec.actions.num_values)
            return model(x, s)

        def initial_state_fn(batch_size: Optional[int] = None):
            model = MyNetwork(spec.actions.num_values)
            return model.initial_state(batch_size)

        # Construct the agent.
        agent = impala.IMPALA(
            environment_spec=spec,
            network=network,
            initial_state_fn=initial_state_fn,
            sequence_length=3,
            sequence_period=3,
            batch_size=6,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
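`MyNetwork` is defined outside this snippet. The JAX IMPALA agent expects a recurrent core whose call returns (policy logits, value estimate) together with the next recurrent state. A minimal hedged sketch; the LSTM size and the joint output head are assumptions:

    import haiku as hk

    class MyNetwork(hk.RNNCore):
        """A tiny recurrent policy/value network (illustrative only)."""

        def __init__(self, num_actions: int, name=None):
            super().__init__(name=name)
            self._core = hk.LSTM(20)
            self._head = hk.Linear(num_actions + 1)

        def __call__(self, inputs, state):
            embedding, new_state = self._core(hk.Flatten()(inputs), state)
            out = self._head(embedding)
            # Split the joint head into per-action logits and a scalar value.
            return (out[..., :-1], out[..., -1]), new_state

        def initial_state(self, batch_size):
            return self._core.initial_state(batch_size)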
Example No. 7
    def test_r2d2(self):
        # Create a fake environment to test with.
        # TODO(b/152596848): Allow R2D2 to deal with integer observations.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 4),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Construct the agent.
        agent = r2d2.R2D2(
            environment_spec=spec,
            network=SimpleNetwork(spec.actions),
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            burn_in_length=2,
            trace_length=6,
            replay_period=4,
            checkpoint=False,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=5)
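`SimpleNetwork` (also used by the R2D3 test in Example No. 20) is defined elsewhere in the test module. R2D2 needs a recurrent Q-network that additionally exposes `initial_state` and `unroll`; a hedged sketch in Sonnet, with layer sizes as assumptions:

    import sonnet as snt
    from acme.tf import networks

    class SimpleNetwork(networks.RNNCore):
        """A small recurrent Q-network (illustrative only)."""

        def __init__(self, action_spec):
            super().__init__(name='r2d2_test_network')
            self._net = snt.DeepRNN([
                snt.Flatten(),
                snt.LSTM(20),
                snt.nets.MLP([50, 50, action_spec.num_values]),
            ])

        def __call__(self, inputs, state):
            return self._net(inputs, state)

        def initial_state(self, batch_size, **kwargs):
            return self._net.initial_state(batch_size)

        def unroll(self, inputs, state, sequence_length):
            return snt.static_unroll(self._net, inputs, state, sequence_length)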
Example No. 8
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example No. 9
    def test_environment_loop(self):
        # Create the actor/environment and stick them in a loop.
        environment = fakes.DiscreteEnvironment(episode_length=10)
        actor = fakes.Actor(specs.make_environment_spec(environment))
        loop = environment_loop.EnvironmentLoop(environment, actor)

        # Run the loop. There should be one update call per environment step,
        # i.e. episode_length updates per episode (10 episodes x 10 steps = 100).
        loop.run(num_episodes=10)
        self.assertEqual(actor.num_updates, 100)
Example No. 10
    def test_pickle_unpickle(self):
        test_env = base.EnvironmentWrapper(
            environment=fakes.DiscreteEnvironment())

        test_env_pickled = pickle.dumps(test_env)
        test_env_restored = pickle.loads(test_env_pickled)
        self.assertEqual(
            test_env.observation_spec(),
            test_env_restored.observation_spec(),
        )
Example No. 11
    def test_raises_value_error(self):
        """Ensure that wrapper raises error if noop_max is <0."""
        base_env = fakes.DiscreteEnvironment(action_dtype=np.int64,
                                             obs_dtype=np.int64,
                                             reward_spec=specs.Array(
                                                 dtype=np.float64, shape=()))

        with self.assertRaises(ValueError):
            wrappers.NoopStartsWrapper(base_env,
                                       noop_action=0,
                                       noop_max=-1,
                                       seed=24)
Example No. 12
    def test_discrete_actions(self, loss_name):
        with chex.fake_pmap_and_jit():

            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.DiscreteEnvironment(num_actions=10,
                                                    num_observations=100,
                                                    obs_shape=(10, ),
                                                    obs_dtype=np.float32)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec, discrete_actions=True)

            def logp_fn(logits, actions):
                max_logits = jnp.max(logits, axis=-1, keepdims=True)
                logits = logits - max_logits
                logits_actions = jnp.sum(
                    jax.nn.one_hot(actions, spec.actions.num_values) * logits,
                    axis=-1)

                log_prob = logits_actions - special.logsumexp(logits, axis=-1)
                return log_prob

            if loss_name == 'logp':
                loss_fn = bc.logp(logp_fn=logp_fn)

            elif loss_name == 'rcal':
                base_loss_fn = bc.logp(logp_fn=logp_fn)
                loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)

            else:
                raise ValueError(f'Unexpected loss name: {loss_name}')

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
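`make_networks(spec, discrete_actions=True)` is not shown; whatever it builds must output raw logits over actions, since `logp_fn` above computes a log-softmax of the network output. A hedged sketch of such a builder; the layer sizes and the `FeedForwardNetwork` wrapper are guesses, not the original helper:

    import haiku as hk
    from acme.jax import networks as networks_lib
    from acme.jax import utils

    def make_networks(spec, discrete_actions: bool = True):
        assert discrete_actions, 'This sketch only covers the discrete case.'

        def policy_fn(obs):
            # Flatten observations and emit one logit per discrete action.
            return hk.nets.MLP([64, 64, spec.actions.num_values])(hk.Flatten()(obs))

        policy = hk.without_apply_rng(hk.transform(policy_fn))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
        return networks_lib.FeedForwardNetwork(
            init=lambda key: policy.init(key, dummy_obs),
            apply=policy.apply)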
Example No. 13
    def test_impala(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def forward_fn(x, s):
            model = MyNetwork(spec.actions.num_values)
            return model(x, s)

        def initial_state_fn(batch_size: Optional[int] = None):
            model = MyNetwork(spec.actions.num_values)
            return model.initial_state(batch_size)

        def unroll_fn(inputs, state):
            model = MyNetwork(spec.actions.num_values)
            return hk.static_unroll(model, inputs, state)

        # We pass pure, Haiku-agnostic functions to the agent.
        forward_fn_transformed = hk.without_apply_rng(
            hk.transform(forward_fn, apply_rng=True))
        unroll_fn_transformed = hk.without_apply_rng(
            hk.transform(unroll_fn, apply_rng=True))
        initial_state_fn_transformed = hk.without_apply_rng(
            hk.transform(initial_state_fn, apply_rng=True))

        # Construct the agent.
        config = impala_agent.IMPALAConfig(
            sequence_length=3,
            sequence_period=3,
            batch_size=6,
        )
        agent = impala.IMPALAFromConfig(
            environment_spec=spec,
            forward_fn=forward_fn_transformed.apply,
            initial_state_init_fn=initial_state_fn_transformed.init,
            initial_state_fn=initial_state_fn_transformed.apply,
            unroll_init_fn=unroll_fn_transformed.init,
            unroll_fn=unroll_fn_transformed.apply,
            config=config,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example No. 14
  def test_discrete(self):
    env = wrappers.SinglePrecisionWrapper(
        fakes.DiscreteEnvironment(
            action_dtype=np.int64, obs_dtype=np.int64, reward_dtype=np.float64))

    self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.int32))
    self.assertTrue(np.issubdtype(env.action_spec().dtype, np.int32))
    self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32))
    self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32))

    timestep = env.reset()
    self.assertEqual(timestep.reward, None)
    self.assertEqual(timestep.discount, None)
    self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32))

    timestep = env.step(0)
    self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32))
    self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32))
    self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32))
Example No. 15
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=_make_network(spec.actions),
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example No. 16
def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None,
                         reward_spec: Optional[types.NestedSpec] = None):
  """Common setup code that, unlike self.setUp, takes arguments.

  Args:
    discount_spec: None, or a (nested) specs.BoundedArray.
    reward_spec: None, or a (nested) specs.Array.
  Returns:
    actor, loop
  """
  env_kwargs = {'episode_length': EPISODE_LENGTH}
  if discount_spec:
    env_kwargs['discount_spec'] = discount_spec
  if reward_spec:
    env_kwargs['reward_spec'] = reward_spec

  environment = fakes.DiscreteEnvironment(**env_kwargs)
  actor = fakes.Actor(specs.make_environment_spec(environment))
  loop = environment_loop.EnvironmentLoop(environment, actor)
  return actor, loop
Example No. 17
    def test_impala(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Construct the agent.
        agent = impala.IMPALA(
            environment_spec=spec,
            network=_make_network(spec.actions),
            sequence_length=3,
            sequence_period=3,
            batch_size=6,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
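Note that `_make_network` here differs from the DQN-style helper sketched after Example No. 1: the TF IMPALA agent needs a recurrent core producing policy logits and a value. A hedged sketch, with layer sizes as assumptions:

    import sonnet as snt
    from acme.tf import networks

    def _make_network(action_spec) -> snt.RNNCore:
        return snt.DeepRNN([
            snt.Flatten(),
            snt.LSTM(20),
            snt.nets.MLP([50, 50]),
            # Emits (policy logits, value) on top of the LSTM embedding.
            networks.PolicyValueHead(action_spec.num_values),
        ])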
Example No. 18
  def test_full_learner(self):
    # Create dataset.
    environment = fakes.DiscreteEnvironment(
        num_actions=5,
        num_observations=10,
        obs_dtype=np.float32,
        episode_length=10)
    spec = specs.make_environment_spec(environment)
    dataset = fakes.transition_dataset(environment).batch(
        2, drop_remainder=True)

    # Build network.
    network = networks.IQNNetwork(
        torso=_make_torso_network(num_outputs=2),
        head=_make_head_network(num_outputs=spec.actions.num_values),
        latent_dim=2,
        num_quantile_samples=1)
    tf2_utils.create_variables(network, [spec.observations])

    # Build learner.
    counter = counting.Counter()
    learner = iqn.IQNLearner(
        network=network,
        target_network=copy.deepcopy(network),
        dataset=dataset,
        learning_rate=1e-4,
        discount=0.99,
        importance_sampling_exponent=0.2,
        target_update_period=1,
        counter=counter)

    # Run a learner step.
    learner.step()

    # Check counts from IQN learner.
    counts = counter.get_counts()
    self.assertEqual(1, counts['steps'])

    # Check learner state.
    self.assertEqual(1, learner.state['num_steps'].numpy())
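`_make_torso_network` and `_make_head_network` are not shown. Plausible stand-ins are simple Sonnet MLPs; the exact original helpers may differ:

    import sonnet as snt

    def _make_torso_network(num_outputs: int) -> snt.Module:
        # Embeds flattened observations into a `num_outputs`-dimensional latent.
        return snt.Sequential([snt.Flatten(), snt.nets.MLP([num_outputs])])

    def _make_head_network(num_outputs: int) -> snt.Module:
        # Maps the quantile-modulated latent to one output per action.
        return snt.nets.MLP([num_outputs])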
Example No. 19
  def test_full_learner(self):
    # Create dataset.
    environment = fakes.DiscreteEnvironment(
        num_actions=5,
        num_observations=10,
        obs_dtype=np.float32,
        episode_length=10)
    spec = specs.make_environment_spec(environment)
    dataset = fakes.transition_dataset(environment).batch(2)

    # Build network.
    g_network = _make_network(spec.actions)
    q_network = _make_network(spec.actions)
    network = discrete_networks.DiscreteFilteredQNetwork(g_network=g_network,
                                                         q_network=q_network,
                                                         threshold=0.5)
    tf2_utils.create_variables(network, [spec.observations])

    # Build learner.
    counter = counting.Counter()
    learner = bcq.DiscreteBCQLearner(
        network=network,
        dataset=dataset,
        learning_rate=1e-4,
        discount=0.99,
        importance_sampling_exponent=0.2,
        target_update_period=100,
        counter=counter)

    # Run a learner step.
    learner.step()

    # Check counts from BC and BCQ learners.
    counts = counter.get_counts()
    self.assertEqual(1, counts['bc_steps'])
    self.assertEqual(1, counts['bcq_steps'])

    # Check learner state.
    self.assertEqual(1, learner.state['bc_num_steps'].numpy())
    self.assertEqual(1, learner.state['bcq_num_steps'].numpy())
Example No. 20
    def test_r2d3(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Build demonstrations.
        dummy_action = np.zeros((), dtype=np.int32)
        recorder = bsuite_demonstrations.DemonstrationRecorder()
        timestep = environment.reset()
        while timestep.step_type is not dm_env.StepType.LAST:
            recorder.step(timestep, dummy_action)
            timestep = environment.step(dummy_action)
        recorder.step(timestep, dummy_action)
        recorder.record_episode()

        # Construct the agent.
        agent = r2d3.R2D3(
            environment_spec=spec,
            network=SimpleNetwork(spec.actions),
            target_network=SimpleNetwork(spec.actions),
            demonstration_dataset=recorder.make_tf_dataset(),
            demonstration_ratio=0.5,
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            burn_in_length=2,
            trace_length=6,
            replay_period=4,
            checkpoint=False,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=5)
Example No. 21
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Mujoco environment and associated demonstration dataset.
        if algo == 'ppo':
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        if algo == 'sac':
            networks = sac.make_networks(spec=spec)
            config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(networks)
        elif algo == 'ppo':
            unroll_length = 5
            distribution_value_networks = make_ppo_networks(spec)
            networks = ppo.make_ppo_networks(distribution_value_networks)
            config = ppo.PPOConfig(unroll_length=unroll_length,
                                   num_minibatches=2,
                                   num_epochs=4,
                                   batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(networks)
        else:
            logpi_fn = None

        if algo == 'ppo':
            embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
        else:
            embedding = lambda x: x

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            if airl_discriminator:
                return ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    h_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)
            else:
                return ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=discriminator_transformed,
            logpi_fn=logpi_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(
                is_sequence_based=(algo == 'ppo'),
                share_iterator=True,
                direct_rl_batch_size=direct_rl_batch_size,
                discriminator_batch_size=2,
                policy_variable_name='policy' if subtract_logpi else None,
                min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent)
        train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))
Example No. 22
    def setUp(self):
        super().setUp()
        # Create the actor/environment and stick them in a loop.
        environment = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
        self.actor = fakes.Actor(specs.make_environment_spec(environment))
        self.loop = environment_loop.EnvironmentLoop(environment, self.actor)
Example No. 23
    def test_deepcopy(self):
        test_env = base.EnvironmentWrapper(
            environment=fakes.DiscreteEnvironment())
        copied_env = copy.deepcopy(test_env)
        del copied_env