Example #1
 def testOnlyTensorSpecIsSupported(self, dtype):
   sparse_spec = tf.SparseTensorSpec([1], tf.int32)
   with self.assertRaisesRegex(NotImplementedError, "not supported.*Sparse"):
     _ = tensor_spec.zero_spec_nest(sparse_spec)
   ragged_spec = tf.RaggedTensorSpec(ragged_rank=0, shape=[3, 5])
   with self.assertRaisesRegex(NotImplementedError, "not supported.*Ragged"):
     _ = tensor_spec.zero_spec_nest(ragged_spec)
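
For contrast, the supported input to zero_spec_nest is a nest of plain TensorSpec objects. A minimal sketch of that path (assuming tf_agents is installed and tensor_spec is tf_agents.specs.tensor_spec; the spec names and shapes are illustrative):

import tensorflow as tf
from tf_agents.specs import tensor_spec

specs = {'obs': tf.TensorSpec([3], tf.float32),
         'mask': tf.TensorSpec([5], tf.int32)}
zeros = tensor_spec.zero_spec_nest(specs)
# zeros['obs'] has shape (3,) and zeros['mask'] has shape (5,); every entry is 0.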
Example #2
 def _get_initial_state(self, batch_size):
     """The initial state, which is the prior."""
     network_state = tensor_spec.zero_spec_nest(
         self._policy_state_spec[0],
         outer_dims=None if batch_size is None else [batch_size])
     latent_state = self._model_network.sample_prior(batch_size)
     last_action = tensor_spec.zero_spec_nest(
         self._policy_state_spec[2],
         outer_dims=None if batch_size is None else [batch_size])
     return (network_state, latent_state, last_action)
Example #3
    def testNestZeroWithOuterDims(self, dtype):
        nested_spec = example_nested_tensor_spec(dtype)
        zeros = tensor_spec.zero_spec_nest(nested_spec, outer_dims=[4])
        zeros_ = self.evaluate(zeros)

        def check_shape_and_zero(spec, value):
            self.assertEqual([4] + spec.shape, value.shape)
            self.assertTrue(np.all(value == 0))

        tf.nest.map_structure(check_shape_and_zero, nested_spec, zeros_)
Example #4
    def testNestZero(self, dtype):
        if dtype == tf.string:
            self.skipTest("Not compatible with string type.")
        nested_spec = example_nested_tensor_spec(dtype)
        zeros = tensor_spec.zero_spec_nest(nested_spec)
        zeros_ = self.evaluate(zeros)

        def check_shape_and_zero(spec, value):
            self.assertEqual(spec.shape, value.shape)
            self.assertTrue(np.all(value == 0))

        tf.nest.map_structure(check_shape_and_zero, nested_spec, zeros_)
Example #5
    def _reset(self):
        """
        Reset the state of the environment to an initial state. States are 
        represented as batched tensors. 
        
        Input:
            init -- type of states to create in a batch
                * 'vac': vacuum state
                * 'X+','X-','Y+','Y-','Z+','Z-': one of the cardinal states
                * 'random': sample batch of random states from 'X+','Y+','Z+'
                
        Output:
            TimeStep object (see tf-agents docs)
            
        """
        self.info = {}  # use to cache some intermediate results
        # Create initial state
        if self.init in ['vac', 'X+', 'X-', 'Y+', 'Y-', 'Z+', 'Z-']:
            psi = self.states[self.init]
            psi_batch = tf.stack([psi] * self.batch_size)
            self._state = self.info['psi_cached'] = psi_batch
            self._original = np.array([self.init] * self.batch_size)
        elif self.init == 'random':
            self._original = np.random.choice(['X+', 'Y+', 'Z+'],
                                              size=self.batch_size)
            psi_batch = [self.states[init] for init in self._original]
            psi_batch = tf.convert_to_tensor(psi_batch, dtype=c64)
            self._state = psi_batch

        # Bookkeeping of episode progress
        self._episode_ended = False
        self._elapsed_steps = 0
        self._episode_return = 0

        # Initialize history of horizon H with actions=0 and measurements=1
        self.history = tensor_spec.zero_spec_nest(
            self.action_spec(), outer_dims=(self.batch_size, ))
        for key in self.history.keys():
            self.history[key] = [self.history[key]] * self.H
        # Initialize history of horizon H*attn_step with measurements=1
        m = tf.ones(shape=[self.batch_size, 1])
        self.history['msmt'] = [m] * self.H * self.attn_step

        # Make observation of horizon H
        observation = {
            'msmt': tf.concat(self.history['msmt'][-self.H:], axis=1),
            'clock': tf.one_hot([0] * self.batch_size, self.T),
            'const': tf.ones(shape=[self.batch_size, 1])
        }

        self._current_time_step_ = ts.restart(observation, self.batch_size)
        return self.current_time_step()
Example #6
    def _get_initial_state(self, batch_size):
        """Returns the initial state of the policy network.

    Args:
      batch_size: A constant or Tensor holding the batch size. Can be None, in
        which case the state will not have a batch dimension added.

    Returns:
      A nest of zero tensors matching the spec of the policy network state.
    """
        return tensor_spec.zero_spec_nest(
            self._policy_state_spec,
            outer_dims=None if batch_size is None else [batch_size])
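
The None-versus-batched behavior described in the docstring can be sketched directly (illustrative spec; same tensor_spec module as in the examples):

state_spec = tf.TensorSpec([2], tf.float32)
unbatched = tensor_spec.zero_spec_nest(state_spec)                # shape (2,)
batched = tensor_spec.zero_spec_nest(state_spec, outer_dims=[4])  # shape (4, 2)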
Example #7
    def testNestZeroWithOuterDimsTensor(self, dtype):
        if dtype == tf.string:
            self.skipTest("Not compatible with string type.")
        nested_spec = example_nested_tensor_spec(dtype)
        zeros = tensor_spec.zero_spec_nest(
            nested_spec, outer_dims=[tf.constant(8, dtype=tf.int32)])
        zeros_ = self.evaluate(zeros)

        def check_shape_and_zero(spec, value):
            self.assertEqual([8] + spec.shape, value.shape)
            self.assertTrue(np.all(value == 0))

        tf.nest.map_structure(check_shape_and_zero, nested_spec, zeros_)
Example #8
    def get_initial_value_state(self,
                                batch_size: types.Int) -> types.NestedTensor:
        """Returns the initial state of the value network.

    Args:
      batch_size: A constant or Tensor holding the batch size. Can be None, in
        which case the state will not have a batch dimension added.

    Returns:
      A nest of zero tensors matching the spec of the value network state.
    """
        return tensor_spec.zero_spec_nest(
            self._value_network.state_spec,
            outer_dims=None if batch_size is None else [batch_size])
Example #9
 def testEmptySpec(self, dtype):
     self.assertEqual((), tensor_spec.zero_spec_nest(()))
     self.assertEqual([], tensor_spec.zero_spec_nest([]))
Example #10
        def rejection_sampling(sample_rejector):
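            # One TensorArray per action-spec entry; row b will hold the
            # accepted (or zero-padded) samples for batch element b.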
            valid_batch_samples = tf.nest.map_structure(
                lambda spec: tf.TensorArray(spec.dtype, size=batch_size),
                self._action_spec)

            for b_indx in tf.range(batch_size):
                k = tf.constant(0)
                # pylint: disable=cell-var-from-loop
                valid_samples = tf.nest.map_structure(
                    lambda spec: tf.TensorArray(spec.dtype, size=num_samples),
                    self._action_spec)

                count = tf.constant(0)
                while count < self._max_rejection_iterations:
                    count += 1
                    mean_sample = tf.nest.map_structure(
                        lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0),
                        mean)
                    var_sample = tf.nest.map_structure(
                        lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0),
                        var)
                    if state is not None:
                        state_sample = tf.nest.map_structure(
                            lambda t: tf.expand_dims(tf.gather(t, b_indx),
                                                     axis=0), state)
                    else:
                        state_sample = None

                    samples = sample_fn(mean_sample, var_sample,
                                        state_sample)  # n, a

                    mask = sample_rejector(samples, state_sample)

                    mask = mask[0, ...]
                    mask_index = tf.where(mask)[:, 0]

                    num_mask = tf.shape(mask_index)[0]
                    if num_mask == 0:
                        continue

                    good_samples = tf.nest.map_structure(
                        lambda t: tf.gather(t, mask_index, axis=1)[0, ...],
                        samples)

                    for sample_idx in range(num_mask):
                        if k >= num_samples:
                            break
                        valid_samples = tf.nest.map_structure(
                            lambda gs, vs: vs.write(
                                k, gs[sample_idx:sample_idx + 1, ...]),
                            good_samples, valid_samples)
                        k += 1

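                # If rejection sampling produced fewer than num_samples valid
                # samples, pad the remaining TensorArray slots with zeros that
                # match the action spec.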
                if k < num_samples:
                    zero_samples = tensor_spec.zero_spec_nest(
                        self._action_spec, outer_dims=(num_samples - k, ))
                    for sample_idx in range(num_samples - k):
                        valid_samples = tf.nest.map_structure(
                            lambda gs, vs: vs.write(
                                k + sample_idx, gs[sample_idx:sample_idx + 1, ...]),
                            zero_samples, valid_samples)

                valid_samples = tf.nest.map_structure(lambda vs: vs.concat(),
                                                      valid_samples)

                valid_batch_samples = tf.nest.map_structure(
                    lambda vbs, vs: vbs.write(b_indx, vs), valid_batch_samples,
                    valid_samples)

            samples_continuous = tf.nest.map_structure(lambda a: a.stack(),
                                                       valid_batch_samples)
            return samples_continuous
Example #11
 def __init__(self, time_step_spec, action_spec):
     zero_action = tensor_spec.zero_spec_nest(action_spec)
     super(IdlePolicy, self).__init__(zero_action, time_step_spec,
                                      action_spec)