def testOnlyTensorSpecIsSupported(self, dtype):
    """zero_spec_nest must reject sparse and ragged specs.

    Each unsupported spec is paired with the regex its NotImplementedError
    message is expected to match. The `dtype` argument is not used by this
    test.
    """
    unsupported_cases = (
        (tf.SparseTensorSpec([1], tf.int32), "not supported.*Sparse"),
        (tf.RaggedTensorSpec(ragged_rank=0, shape=[3, 5]),
         "not supported.*Ragged"),
    )
    for unsupported_spec, message_regex in unsupported_cases:
        with self.assertRaisesRegex(NotImplementedError, message_regex):
            tensor_spec.zero_spec_nest(unsupported_spec)
def _get_initial_state(self, batch_size):
    """Build the initial policy state, using the prior for the latent part.

    Args:
        batch_size: Batch size used as the single outer dimension of the
            zero-filled parts and forwarded to the model network's
            `sample_prior`. When None, the zero-filled parts get no outer
            dimension.

    Returns:
        A tuple `(network_state, latent_state, last_action)`: zeros for
        element 0 of the policy state spec, a sample from the model
        network's prior, and zeros for element 2 of the policy state spec.
    """

    def zeros_for(spec):
        # One place for the "optional batch dimension" rule shared by both
        # zero-initialised parts of the state.
        if batch_size is None:
            dims = None
        else:
            dims = [batch_size]
        return tensor_spec.zero_spec_nest(spec, outer_dims=dims)

    network_state = zeros_for(self._policy_state_spec[0])
    latent_state = self._model_network.sample_prior(batch_size)
    last_action = zeros_for(self._policy_state_spec[2])
    return (network_state, latent_state, last_action)
def testNestZeroWithOuterDims(self, dtype):
    """zero_spec_nest with `outer_dims=[4]` prepends a dimension of size 4.

    Every leaf of the returned nest must have shape `[4] + spec.shape` and
    contain only zeros.
    """
    # The `value == 0` check below is not meaningful for string tensors; the
    # sibling zero-nest tests skip that dtype for the same reason.
    if dtype == tf.string:
        self.skipTest("Not compatible with string type.")

    nested_spec = example_nested_tensor_spec(dtype)
    zeros = tensor_spec.zero_spec_nest(nested_spec, outer_dims=[4])
    zeros_ = self.evaluate(zeros)

    def check_shape_and_zero(spec, value):
        self.assertEqual([4] + spec.shape, value.shape)
        self.assertTrue(np.all(value == 0))

    tf.nest.map_structure(check_shape_and_zero, nested_spec, zeros_)
def testNestZero(self, dtype):
    """zero_spec_nest without outer dims yields zeros shaped like each spec."""
    if dtype == tf.string:
        self.skipTest("Not compatible with string type.")

    spec_nest = example_nested_tensor_spec(dtype)
    evaluated_nest = self.evaluate(tensor_spec.zero_spec_nest(spec_nest))

    def assert_zero_with_spec_shape(leaf_spec, leaf_value):
        # Shape must match the spec exactly; contents must be all zero.
        self.assertEqual(leaf_spec.shape, leaf_value.shape)
        self.assertTrue(np.all(leaf_value == 0))

    tf.nest.map_structure(
        assert_zero_with_spec_shape, spec_nest, evaluated_nest)
def _reset(self):
    """Reset the environment to an initial state and start a new episode.

    States are represented as batched tensors. The kind of initial state is
    selected by the attribute ``self.init`` (this method takes no
    arguments):
        * 'vac': vacuum state
        * 'X+','X-','Y+','Y-','Z+','Z-': one of the cardinal states
        * 'random': sample batch of random states from 'X+','Y+','Z+'

    Returns:
        The result of ``self.current_time_step()`` after a restart TimeStep
        (see tf-agents docs) has been stored.

    Raises:
        ValueError: if ``self.init`` is not one of the options above.
            Without this check the state would silently be left unset (or
            stale from the previous episode).
    """
    self.info = {}  # use to cache some intermediate results

    # Create initial state
    fixed_inits = ('vac', 'X+', 'X-', 'Y+', 'Y-', 'Z+', 'Z-')
    if self.init in fixed_inits:
        psi = self.states[self.init]
        psi_batch = tf.stack([psi] * self.batch_size)
        self._state = self.info['psi_cached'] = psi_batch
        self._original = np.array([self.init] * self.batch_size)
    elif self.init == 'random':
        self._original = np.random.choice(['X+', 'Y+', 'Z+'],
                                          size=self.batch_size)
        psi_batch = [self.states[init] for init in self._original]
        psi_batch = tf.convert_to_tensor(psi_batch, dtype=c64)
        self._state = psi_batch
    else:
        raise ValueError(
            "Unknown initial state {!r}; expected one of {} or "
            "'random'.".format(self.init, fixed_inits))

    # Bookkeeping of episode progress
    self._episode_ended = False
    self._elapsed_steps = 0
    self._episode_return = 0

    # Initialize history of horizon H with actions=0 and measurements=1
    self.history = tensor_spec.zero_spec_nest(
        self.action_spec(), outer_dims=(self.batch_size,))
    for key in self.history.keys():
        self.history[key] = [self.history[key]] * self.H

    # Initialize history of horizon H*attn_step with measurements=1
    m = tf.ones(shape=[self.batch_size, 1])
    self.history['msmt'] = [m] * self.H * self.attn_step

    # Make observation of horizon H
    observation = {
        'msmt': tf.concat(self.history['msmt'][-self.H:], axis=1),
        'clock': tf.one_hot([0] * self.batch_size, self.T),
        'const': tf.ones(shape=[self.batch_size, 1]),
    }

    self._current_time_step_ = ts.restart(observation, self.batch_size)
    return self.current_time_step()
def _get_initial_state(self, batch_size):
    """Create a zero-filled initial state for the policy network.

    Args:
        batch_size: A constant or Tensor holding the batch size. May be
            None, in which case no batch dimension is added to the state.

    Returns:
        A nest of zero tensors matching the policy state spec.
    """
    if batch_size is None:
        outer_dims = None
    else:
        outer_dims = [batch_size]
    return tensor_spec.zero_spec_nest(
        self._policy_state_spec, outer_dims=outer_dims)
def testNestZeroWithOuterDimsTensor(self, dtype):
    """zero_spec_nest accepts a Tensor-valued outer dimension (here 8)."""
    if dtype == tf.string:
        self.skipTest("Not compatible with string type.")

    spec_nest = example_nested_tensor_spec(dtype)
    batch_dim = tf.constant(8, dtype=tf.int32)
    evaluated_nest = self.evaluate(
        tensor_spec.zero_spec_nest(spec_nest, outer_dims=[batch_dim]))

    def assert_batched_zeros(leaf_spec, leaf_value):
        # The tensor-valued outer dim must show up as a leading 8.
        self.assertEqual([8] + leaf_spec.shape, leaf_value.shape)
        self.assertTrue(np.all(leaf_value == 0))

    tf.nest.map_structure(assert_batched_zeros, spec_nest, evaluated_nest)
def get_initial_value_state(self,
                            batch_size: types.Int) -> types.NestedTensor:
    """Create a zero-filled initial state for the value network.

    Args:
        batch_size: A constant or Tensor holding the batch size. May be
            None, in which case no batch dimension is added to the state.

    Returns:
        A nest of zero tensors matching the value network's state spec.
    """
    value_state_spec = self._value_network.state_spec
    if batch_size is None:
        return tensor_spec.zero_spec_nest(value_state_spec, outer_dims=None)
    return tensor_spec.zero_spec_nest(
        value_state_spec, outer_dims=[batch_size])
def testEmptySpec(self, dtype):
    """An empty tuple or list spec maps to an equal empty container.

    The `dtype` argument is not used by this test.
    """
    for make_empty in (tuple, list):
        # Build fresh empties for the expectation and the argument.
        self.assertEqual(
            make_empty(), tensor_spec.zero_spec_nest(make_empty()))
def rejection_sampling(sample_rejector):
    """Collect up to `num_samples` accepted samples per batch element.

    Relies on names from the enclosing scope that are not visible in this
    block: `self`, `batch_size`, `num_samples`, `mean`, `var`, `state` and
    `sample_fn`.

    For each batch index, samples are drawn with `sample_fn` for at most
    `self._max_rejection_iterations` rounds. Samples selected by the mask
    returned from `sample_rejector` are written into a per-element
    TensorArray until `num_samples` entries are filled. Slots still empty
    afterwards are filled with zeros built from `self._action_spec`.

    Args:
        sample_rejector: Callable invoked as
            `sample_rejector(samples, state_sample)`. Entry 0 along the
            leading axis of its result is used as the acceptance mask.

    Returns:
        A nest structured like `self._action_spec`, produced by stacking
        the per-batch-element results.
    """
    valid_batch_samples = tf.nest.map_structure(
        lambda spec: tf.TensorArray(spec.dtype, size=batch_size),
        self._action_spec)

    for b_indx in tf.range(batch_size):
        # Number of accepted samples written so far for this batch element.
        k = tf.constant(0)
        # pylint: disable=cell-var-from-loop
        valid_samples = tf.nest.map_structure(
            lambda spec: tf.TensorArray(spec.dtype, size=num_samples),
            self._action_spec)

        count = tf.constant(0)
        while count < self._max_rejection_iterations:
            count += 1
            # Slice out this batch element, keeping a leading axis of size 1.
            mean_sample = tf.nest.map_structure(
                lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0),
                mean)
            var_sample = tf.nest.map_structure(
                lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0),
                var)
            if state is not None:
                state_sample = tf.nest.map_structure(
                    lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0),
                    state)
            else:
                state_sample = None

            samples = sample_fn(mean_sample, var_sample, state_sample)  # n, a

            mask = sample_rejector(samples, state_sample)
            mask = mask[0, ...]
            mask_index = tf.where(mask)[:, 0]
            num_mask = tf.shape(mask_index)[0]
            if num_mask == 0:
                # Nothing accepted in this round; draw again.
                continue

            good_samples = tf.nest.map_structure(
                lambda t: tf.gather(t, mask_index, axis=1)[0, ...], samples)

            for sample_idx in range(num_mask):
                if k >= num_samples:
                    break
                valid_samples = tf.nest.map_structure(
                    lambda gs, vs: vs.write(
                        k, gs[sample_idx:sample_idx + 1, ...]),
                    good_samples, valid_samples)
                k += 1

        if k < num_samples:
            # Not enough accepted samples: pad the remaining slots with zeros.
            zero_samples = tensor_spec.zero_spec_nest(
                self._action_spec, outer_dims=(num_samples - k,))
            for sample_idx in range(num_samples - k):
                # Write to consecutive slots k, k + 1, ...; writing every
                # padding slice to index k would leave the later slots
                # unwritten when the array is concatenated below.
                valid_samples = tf.nest.map_structure(
                    lambda gs, vs: vs.write(
                        k + sample_idx, gs[sample_idx:sample_idx + 1, ...]),
                    zero_samples, valid_samples)

        valid_samples = tf.nest.map_structure(
            lambda vs: vs.concat(), valid_samples)

        valid_batch_samples = tf.nest.map_structure(
            lambda vbs, vs: vbs.write(b_indx, vs),
            valid_batch_samples, valid_samples)

    samples_continuous = tf.nest.map_structure(
        lambda a: a.stack(), valid_batch_samples)
    return samples_continuous
def __init__(self, time_step_spec, action_spec):
    """Initialise the policy with an all-zero action built from the spec.

    Args:
        time_step_spec: Forwarded unchanged to the parent constructor.
        action_spec: Spec used to build the zero action; also forwarded
            unchanged to the parent constructor.
    """
    super(IdlePolicy, self).__init__(
        tensor_spec.zero_spec_nest(action_spec),
        time_step_spec,
        action_spec)