def testNetworkStateWithPassthroughMask(self):
    """Checks network-state handling with and without `passthrough_mask`."""
    # A single stateful inner network is shared by both splitters below.
    inner_network = WrappedNetwork(state_spec=self._state_spec)

    # First splitter: forwards the mask (`passthrough_mask=True`).
    passthrough_splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network,
        passthrough_mask=True)
    # Second splitter: built without the `passthrough_mask` argument.
    plain_splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network)

    # Run both splitters on the same input and the same incoming state; the
    # second element of each result is the output network state.
    _, passthrough_state = passthrough_splitter(
        self._observation_and_mask, network_state=self._network_state)
    _, plain_state = plain_splitter(
        self._observation_and_mask, network_state=self._network_state)

    # The state handed in must come back unchanged ...
    self.assertAllEqual(passthrough_state, self._network_state)
    # ... and must not depend on whether the mask is passed through.
    self.assertAllEqual(passthrough_state, plain_state)

    # Both splitters expose the `state_spec` of the network they wrap.
    for splitter in (passthrough_splitter, plain_splitter):
        self.assertEqual(splitter.state_spec, inner_network.state_spec)
def testDistributionNetworkWithPassthroughMask(self):
    """Checks a wrapped distribution network when the mask is passed on."""
    # Inner network that produces a distribution; it has no state.
    inner_network = WrappedDistributionNetwork(
        input_tensor_spec=self._observation_spec,
        state_spec=(),
        output_spec=self._output_spec,
        name='WrappedDistributionNetwork')

    # Splitter configured to hand the mask to the inner network.
    splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network,
        passthrough_mask=True)

    # Only the distribution is of interest; the returned state is ignored.
    distribution, _ = splitter(self._observation_and_mask)

    # The distribution's logits must match the observation part of the input.
    expected_logits = self._observation_and_mask['observation']
    actual_logits = distribution.parameters['logits']
    self.assertAllClose(expected_logits, actual_logits)

    # The mask recorded by the inner network is the mask part of the input.
    self.assertAllEqual(inner_network.mask,
                        self._observation_and_mask['mask'])

    # The splitter reports the same `output_spec` as the network it wraps.
    self.assertEqual(splitter.output_spec, inner_network.output_spec)
def testDistributionNetwork(self):
    """Checks a wrapped distribution network when the mask is dropped."""
    # Inner network that produces a distribution; it has no state.
    inner_network = WrappedDistributionNetwork(
        input_tensor_spec=self._observation_spec,
        state_spec=(),
        output_spec=self._output_spec,
        name='WrappedDistributionNetwork')

    # Splitter built without `passthrough_mask`, so the mask is not forwarded.
    splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network)

    # Only the distribution is of interest; the returned state is ignored.
    distribution, _ = splitter(self._observation_and_mask)

    # The distribution's logits must match the observation part of the input.
    actual_logits = distribution.parameters['logits']
    self.assertAllClose(actual_logits,
                        self._observation_and_mask['observation'])

    # No mask may have reached the inner network.
    self.assertIsNone(inner_network.mask)

    # The splitter reports the same `output_spec` as the network it wraps.
    self.assertEqual(splitter.output_spec, inner_network.output_spec)
def testSimpleNetwork(self):
    """Checks that a plain wrapped network gets the observation but no mask."""
    inner_network = WrappedNetwork()

    # Splitter built without `passthrough_mask`, so the mask is not forwarded.
    splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network)

    # The first output is the (observation, mask) pair seen by the inner
    # network; the second output (the state) is ignored.
    received, _ = splitter(self._observation_and_mask)
    received_observation, received_mask = received

    # The observation part of the input reached the inner network.
    self.assertAllClose(received_observation,
                        self._observation_and_mask['observation'])

    # The mask was dropped before the inner network was called.
    self.assertIsNone(received_mask)
def testSimpleNetworkWithPassthroughMask(self):
    """Checks that a plain wrapped network gets both observation and mask."""
    inner_network = WrappedNetwork()

    # Splitter configured to hand the mask to the inner network.
    splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network,
        passthrough_mask=True)

    # The first output is the (observation, mask) pair seen by the inner
    # network; the second output (the state) is ignored.
    received, _ = splitter(self._observation_and_mask)
    received_observation, received_mask = received

    # The observation part of the input reached the inner network.
    self.assertAllClose(received_observation,
                        self._observation_and_mask['observation'])

    # The mask part of the input reached the inner network as well.
    self.assertAllEqual(received_mask, self._observation_and_mask['mask'])
def testNetworkState(self):
    """Checks that network state is routed through the splitter unchanged."""
    # Stateful inner network.
    inner_network = WrappedNetwork(state_spec=self._state_spec)

    # Splitter built without `passthrough_mask`, so the mask is not forwarded.
    splitter = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network)

    # Call the splitter with an explicit incoming state; the second element
    # of the result is the output network state.
    _, returned_state = splitter(
        self._observation_and_mask, network_state=self._network_state)

    # The state handed in must come back unchanged.
    self.assertAllEqual(returned_state, self._network_state)

    # The splitter exposes the `state_spec` of the network it wraps.
    self.assertEqual(splitter.state_spec, inner_network.state_spec)
def testCopyUsesSameWrappedNetwork(self):
    """Checks that `copy(wrapped_network=...)` reuses the given network."""
    inner_network = value_network.ValueNetwork(
        self._observation_spec, fc_layer_params=(2,))

    # Build the original splitter and create its variables.
    original = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network,
        passthrough_mask=True,
        input_tensor_spec=self._observation_and_mask_spec)
    original.create_variables()

    # Create a copy while explicitly passing the very same wrapped network.
    duplicate = original.copy(wrapped_network=inner_network)

    # Both splitters must hold the same wrapped network object.
    self.assertIs(duplicate._wrapped_network, original._wrapped_network)
def testCopyCreateNewInstanceOfNetworkIfNotRedefined(self):
    """Checks that a plain `copy()` does not share the wrapped network."""
    inner_network = value_network.ValueNetwork(
        self._observation_spec, fc_layer_params=(2,))

    # Build the original splitter and create its variables.
    original = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn=self._splitter_fn,
        wrapped_network=inner_network,
        passthrough_mask=True,
        input_tensor_spec=self._observation_and_mask_spec)
    original.create_variables()

    # Copy without overriding any argument, then build the copy too.
    duplicate = original.copy()
    duplicate.create_variables()

    # The copy must hold a different wrapped network object.
    self.assertIsNot(duplicate._wrapped_network, original._wrapped_network)
def testPolicyStepWithActionMaskTurnedOn(self):
    """Checks that a PPO policy step honours the action mask in its logits.

    The actor network is wrapped in a `MaskSplitterNetwork` with
    `passthrough_mask=True`; the value network is wrapped without it. After
    one policy step, the logits at masked-out action indices must equal the
    float32 minimum and the logits at allowed indices must be above it.
    """
    # Create specs with action constraints (mask).
    num_categories = 5
    observation_tensor_spec = (
        tensor_spec.TensorSpec(
            shape=(3,), dtype=tf.int64, name='network_spec'),
        tensor_spec.TensorSpec(
            shape=(num_categories,), dtype=tf.bool, name='mask_spec'),
    )
    network_spec, _ = observation_tensor_spec
    action_tensor_spec = tensor_spec.BoundedTensorSpec(
        (1,), tf.int32, 0, num_categories - 1)

    # Create policy with splitter. The observation is an
    # `(observation, mask)` tuple, so the splitter just unpacks it.
    def splitter_fn(observation_and_mask):
        return observation_and_mask[0], observation_and_mask[1]

    actor_network = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn,
        actor_distribution_network.ActorDistributionNetwork(
            network_spec, action_tensor_spec),
        passthrough_mask=True)
    # Named `critic_network` so the local does not shadow the
    # `value_network` module name used by other tests in this file.
    critic_network = mask_splitter_network.MaskSplitterNetwork(
        splitter_fn, value_net.ValueNetwork(network_spec))
    policy = ppo_policy.PPOPolicy(
        ts.time_step_spec(observation_tensor_spec),
        action_tensor_spec,
        actor_network=actor_network,
        value_network=critic_network,
        clip=False)

    # Take a step.
    # Use the builtin `bool` as dtype: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises `AttributeError`.
    mask = np.array([True, False, True, False, True], dtype=bool)
    self.assertLen(mask, num_categories)
    # NOTE(review): the observation has two rows while step_type, reward,
    # discount and the mask each have one; the shape check below expects a
    # leading dimension of 2.
    time_step = ts.TimeStep(
        step_type=tf.constant([1], dtype=tf.int32),
        reward=tf.constant([1], dtype=tf.float32),
        discount=tf.constant([1], dtype=tf.float32),
        observation=(
            tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64),
            tf.constant([mask.tolist()], dtype=tf.bool),
        ))
    action_step = policy.action(time_step)

    # Check the shape and type of the resulting action step.
    self.assertEqual(action_step.action.shape.as_list(), [2, 1])
    self.assertEqual(action_step.action.dtype, tf.int32)

    # Variables must be initialized before anything is evaluated.
    self.evaluate(tf.compat.v1.global_variables_initializer())

    # Check that the actions stay within the bounds of the action spec.
    actions = self.evaluate(action_step.action)
    self.assertTrue(np.all(actions >= action_tensor_spec.minimum))
    self.assertTrue(np.all(actions <= action_tensor_spec.maximum))

    # Check the logits: masked-out actions sit at the float32 minimum,
    # allowed actions strictly above it.
    logits = np.array(
        self.evaluate(action_step.info['dist_params']['logits']),
        dtype=np.float32)
    float32_min = np.finfo(np.float32).min
    masked_actions = np.flatnonzero(~mask)
    self.assertTrue(np.all(logits[:, :, masked_actions] == float32_min))
    valid_actions = np.flatnonzero(mask)
    self.assertTrue(np.all(logits[:, :, valid_actions] > float32_min))