コード例 #1
0
    def testNetworkStateWithPassthroughMask(self):
        # Verifies that the network state returned by the splitter equals the
        # state that was passed in, that the result is the same with and
        # without `passthrough_mask`, and that the splitter reports the
        # `state_spec` of the network it wraps.
        inner_network = WrappedNetwork(state_spec=self._state_spec)

        # Two splitters share one inner network: the first forwards the mask
        # (`passthrough_mask=True`), the second leaves the argument unset.
        masked_splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=inner_network,
            passthrough_mask=True)
        plain_splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn, wrapped_network=inner_network)

        # Feed the same input and the same initial state to both splitters and
        # keep only the network state each of them returns.
        _, masked_state = masked_splitter(
            self._observation_and_mask, network_state=self._network_state)
        _, plain_state = plain_splitter(
            self._observation_and_mask, network_state=self._network_state)

        # The returned state must match the state that was supplied.
        self.assertAllEqual(masked_state, self._network_state)
        # Passing the mask through must not alter the returned state.
        self.assertAllEqual(masked_state, plain_state)
        # Each splitter must expose the `state_spec` of the inner network.
        for splitter in (masked_splitter, plain_splitter):
            self.assertEqual(splitter.state_spec, inner_network.state_spec)
コード例 #2
0
    def testDistributionNetworkWithPassthroughMask(self):
        # Verifies that a splitter wrapping a distribution network with
        # `passthrough_mask=True` builds the distribution from the observation,
        # hands the mask to the wrapped network, and reports the wrapped
        # network's `output_spec`.
        inner_network = WrappedDistributionNetwork(
            input_tensor_spec=self._observation_spec,
            state_spec=(),
            output_spec=self._output_spec,
            name='WrappedDistributionNetwork')

        # The splitter forwards the mask to the inner network.
        splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=inner_network,
            passthrough_mask=True)

        # Only the distribution is of interest; the returned state is dropped.
        produced_distribution, _ = splitter(self._observation_and_mask)

        # The distribution's logits must equal the input observation.
        self.assertAllClose(
            self._observation_and_mask['observation'],
            produced_distribution.parameters['logits'])
        # The inner network must have recorded the mask from the input.
        self.assertAllEqual(
            inner_network.mask, self._observation_and_mask['mask'])
        # The splitter must expose the `output_spec` of the inner network.
        self.assertEqual(splitter.output_spec, inner_network.output_spec)
コード例 #3
0
    def testDistributionNetwork(self):
        # Verifies that a splitter wrapping a distribution network without
        # `passthrough_mask` builds the distribution from the observation,
        # withholds the mask from the wrapped network, and reports the wrapped
        # network's `output_spec`.
        inner_network = WrappedDistributionNetwork(
            input_tensor_spec=self._observation_spec,
            state_spec=(),
            output_spec=self._output_spec,
            name='WrappedDistributionNetwork')

        # `passthrough_mask` is left unset, so the mask is dropped.
        splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn, wrapped_network=inner_network)

        # Only the distribution is of interest; the returned state is dropped.
        produced_distribution, _ = splitter(self._observation_and_mask)

        # The distribution's logits must equal the input observation.
        produced_logits = produced_distribution.parameters['logits']
        expected_logits = self._observation_and_mask['observation']
        self.assertAllClose(produced_logits, expected_logits)
        # No mask may reach the inner network when it is not passed through.
        self.assertIsNone(inner_network.mask)
        # The splitter must expose the `output_spec` of the inner network.
        self.assertEqual(splitter.output_spec, inner_network.output_spec)
コード例 #4
0
    def testSimpleNetwork(self):
        # Verifies that, without `passthrough_mask`, the wrapped network gets
        # the observation part of the input and no mask.
        inner_network = WrappedNetwork()

        # `passthrough_mask` is left unset, so the mask is dropped.
        splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn, wrapped_network=inner_network)

        # The output is the (observation, mask) pair the inner network
        # received; the returned state is ignored.
        received, _ = splitter(self._observation_and_mask)
        received_observation, received_mask = received

        # The observation must arrive unchanged.
        expected_observation = self._observation_and_mask['observation']
        self.assertAllClose(received_observation, expected_observation)
        # No mask may reach the inner network when it is not passed through.
        self.assertIsNone(received_mask)
コード例 #5
0
    def testSimpleNetworkWithPassthroughMask(self):
        # Verifies that, with `passthrough_mask=True`, the wrapped network
        # gets both the observation and the mask from the input.
        inner_network = WrappedNetwork()

        # The splitter forwards the mask to the inner network.
        splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=inner_network,
            passthrough_mask=True)

        # The output is the (observation, mask) pair the inner network
        # received; the returned state is ignored.
        received, _ = splitter(self._observation_and_mask)
        received_observation, received_mask = received

        # The observation must arrive unchanged.
        expected_observation = self._observation_and_mask['observation']
        self.assertAllClose(received_observation, expected_observation)
        # The mask must be the one contained in the input.
        self.assertAllEqual(received_mask, self._observation_and_mask['mask'])
コード例 #6
0
    def testNetworkState(self):
        # Verifies that the network state returned by the splitter equals the
        # state that was passed in and that the splitter reports the
        # `state_spec` of the network it wraps.
        inner_network = WrappedNetwork(state_spec=self._state_spec)

        # `passthrough_mask` is left unset, so the mask is dropped.
        splitter = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn, wrapped_network=inner_network)

        # Call the splitter with an explicit state and keep only the state it
        # returns.
        _, returned_state = splitter(
            self._observation_and_mask, network_state=self._network_state)

        # The returned state must match the state that was supplied.
        self.assertAllEqual(returned_state, self._network_state)
        # The splitter must expose the `state_spec` of the inner network.
        self.assertEqual(splitter.state_spec, inner_network.state_spec)
コード例 #7
0
    def testCopyUsesSameWrappedNetwork(self):
        # Verifies that `copy()` keeps the wrapped network object when that
        # object is handed to `copy()` explicitly.
        inner_network = value_network.ValueNetwork(
            self._observation_spec, fc_layer_params=(2, ))

        # Create the original splitter and build its variables.
        original = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=inner_network,
            passthrough_mask=True,
            input_tensor_spec=self._observation_and_mask_spec)
        original.create_variables()

        # Copy the splitter, explicitly supplying the same wrapped network.
        duplicate = original.copy(wrapped_network=inner_network)

        # Both splitters must hold the identical wrapped network object.
        self.assertIs(duplicate._wrapped_network, original._wrapped_network)
コード例 #8
0
    def testCopyCreateNewInstanceOfNetworkIfNotRedefined(self):
        # Verifies that `copy()` called without a `wrapped_network` argument
        # yields a splitter holding a different wrapped network object.
        inner_network = value_network.ValueNetwork(
            self._observation_spec, fc_layer_params=(2, ))

        # Create the original splitter and build its variables.
        original = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=inner_network,
            passthrough_mask=True,
            input_tensor_spec=self._observation_and_mask_spec)
        original.create_variables()

        # Copy without overriding anything, then build the copy's variables.
        duplicate = original.copy()
        duplicate.create_variables()

        # The two splitters must not share the wrapped network object.
        self.assertIsNot(duplicate._wrapped_network, original._wrapped_network)
コード例 #9
0
ファイル: ppo_policy_test.py プロジェクト: wuzh07/agents
    def testPolicyStepWithActionMaskTurnedOn(self):
        # Runs one `PPOPolicy.action` step with an actor network that receives
        # the action mask and checks the shape, dtype and bounds of the action
        # as well as the logits of masked and valid actions.

        # Create specs with action constraints (mask).
        num_categories = 5
        observation_tensor_spec = (
            tensor_spec.TensorSpec(shape=(3, ),
                                   dtype=tf.int64,
                                   name='network_spec'),
            tensor_spec.TensorSpec(shape=(num_categories, ),
                                   dtype=tf.bool,
                                   name='mask_spec'),
        )
        network_spec, _ = observation_tensor_spec
        action_tensor_spec = tensor_spec.BoundedTensorSpec((1, ), tf.int32, 0,
                                                           num_categories - 1)

        # Create policy with splitter. The observation is a
        # (network input, mask) tuple, so splitting is plain indexing.
        def splitter_fn(observation_and_mask):
            return observation_and_mask[0], observation_and_mask[1]

        # The actor network gets the mask; the value network does not.
        actor_network = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn,
            actor_distribution_network.ActorDistributionNetwork(
                network_spec, action_tensor_spec),
            passthrough_mask=True)
        value_network = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn, value_net.ValueNetwork(network_spec))
        policy = ppo_policy.PPOPolicy(
            ts.time_step_spec(observation_tensor_spec),
            action_tensor_spec,
            actor_network=actor_network,
            value_network=value_network,
            clip=False)

        # Take a step.
        # Use the builtin `bool` as dtype: the `np.bool` alias was deprecated
        # in NumPy 1.20 and removed in NumPy 1.24.
        mask = np.array([True, False, True, False, True], dtype=bool)
        self.assertLen(mask, num_categories)
        # NOTE(review): the mask tensor has a leading dimension of 1 while the
        # observation has 2 rows; presumably this relies on broadcasting --
        # confirm.
        time_step = ts.TimeStep(step_type=tf.constant([1], dtype=tf.int32),
                                reward=tf.constant([1], dtype=tf.float32),
                                discount=tf.constant([1], dtype=tf.float32),
                                observation=(tf.constant(
                                    [[1, 2, 3], [4, 5, 6]], dtype=tf.int64),
                                             tf.constant([mask.tolist()],
                                                         dtype=tf.bool)))
        action_step = policy.action(time_step)

        # Check the shape and type of the resulted action step.
        self.assertEqual(action_step.action.shape.as_list(), [2, 1])
        self.assertEqual(action_step.action.dtype, tf.int32)
        self.evaluate(tf.compat.v1.global_variables_initializer())

        # Check the actions in general and with respect to masking.
        actions = self.evaluate(action_step.action)
        self.assertTrue(np.all(actions >= action_tensor_spec.minimum))
        self.assertTrue(np.all(actions <= action_tensor_spec.maximum))

        # Check the logits: masked actions must sit at the float32 minimum,
        # valid actions must be strictly above it.
        logits = np.array(self.evaluate(
            action_step.info['dist_params']['logits']),
                          dtype=np.float32)
        float32_min = np.finfo(np.float32).min
        # Boolean indexing on the last axis selects the masked / valid
        # action columns directly.
        self.assertTrue(np.all(logits[:, :, ~mask] == float32_min))
        self.assertTrue(np.all(logits[:, :, mask] > float32_min))