def actions_and_log_probs(self, observations):
    """Compute actions and log probabilities together.

    We need this function to avoid numerical issues coming out of the
    squashing bijector (`tfp.bijectors.Tanh`). Ideally this would be
    avoided by using caching of the bijector and then computing actions
    and log probs separately, but that's currently not possible due to
    an issue with bijector caching in graph mode (i.e. within
    `tf.function`). This method could be removed once the caching
    works. For more, see:
    https://github.com/tensorflow/probability/issues/840
    """
    observations = self._filter_observations(observations)

    first_observation = tree.flatten(observations)[0]
    first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
    batch_shape = tf.shape(first_observation)[:-first_input_rank]

    shifts, scales = self.shift_and_scale_model(observations)
    actions = self.action_distribution.sample(
        batch_shape,
        bijector_kwargs={'scale': {'scale': scales},
                         'shift': {'shift': shifts}})
    log_probs = self.action_distribution.log_prob(
        actions,
        bijector_kwargs={'scale': {'scale': scales},
                         'shift': {'shift': shifts}}
    )[..., tf.newaxis]

    return actions, log_probs
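# A minimal, hypothetical sketch of the numerical issue the method above
# works around (the names `base_distribution` and `squashed_distribution`
# are illustrative, not part of this class): recovering `log_prob` for a
# tanh-squashed sample requires inverting the tanh, which saturates near
# +/-1. Sampling and scoring in one call keeps both on the same tensors.
import tensorflow as tf
import tensorflow_probability as tfp

base_distribution = tfp.distributions.MultivariateNormalDiag(
    loc=tf.zeros(2), scale_diag=tf.ones(2))
squashed_distribution = tfp.distributions.TransformedDistribution(
    distribution=base_distribution, bijector=tfp.bijectors.Tanh())

actions = squashed_distribution.sample(16)
# Inside `tf.function`, bijector caching is unreliable (see the issue
# linked above), so this `log_prob` call has to invert the tanh; samples
# that saturate numerically to +/-1 produce inf/NaN log probabilities.
log_probs = squashed_distribution.log_prob(actions)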
def actions(self, observations):
    """Sample actions for the given observations."""
    first_observation = tree.flatten(observations)[0]
    first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
    batch_shape = tf.shape(first_observation)[:-first_input_rank]
    actions = self.distribution.sample(batch_shape)
    return actions
def actions(self, observations): """Compute actions for given observations.""" observations = self._filter_observations(observations) first_observation = tree.flatten(observations)[0] first_input_rank = tf.size(tree.flatten(self._input_shapes)[0]) batch_shape = tf.shape(first_observation)[:-first_input_rank] shifts, scales = self.shift_and_scale_model(observations) actions = self.action_distribution.sample( batch_shape, bijector_kwargs={'scale': {'scale': scales}, 'shift': {'shift': shifts}}) return actions
def add_samples(self, samples):
    num_samples = tree.flatten(samples)[0].shape[0]

    # Either both episode index fields are provided or neither is; in
    # the latter case, fill them with their default values.
    assert (('episode_index_forwards' in samples.keys())
            is ('episode_index_backwards' in samples.keys()))

    if 'episode_index_forwards' not in samples.keys():
        samples['episode_index_forwards'] = np.full(
            (num_samples, *self.fields['episode_index_forwards'].shape),
            self.fields['episode_index_forwards'].default_value,
            dtype=self.fields['episode_index_forwards'].dtype)
        samples['episode_index_backwards'] = np.full(
            (num_samples, *self.fields['episode_index_backwards'].shape),
            self.fields['episode_index_backwards'].default_value,
            dtype=self.fields['episode_index_backwards'].dtype)

    # Ring-buffer indices: wrap around once the pointer passes the end.
    index = np.arange(
        self._pointer, self._pointer + num_samples) % self._max_size

    def add_sample(path, data, new_values, field):
        assert new_values.shape[0] == num_samples, (
            new_values.shape, num_samples)
        data[index] = new_values

    tree.map_structure_with_path(
        add_sample, self.data, samples, self.fields)

    self._advance(num_samples)
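# A toy illustration of the wraparound indexing used by `add_samples`
# above; the numbers are made up. Writing through
# `np.arange(...) % max_size` overwrites the oldest entries once the
# buffer is full.
import numpy as np

max_size = 5
data = np.zeros(max_size, dtype=np.int64)
pointer = 3
new_values = np.array([10, 11, 12])  # three samples starting at index 3

index = np.arange(pointer, pointer + len(new_values)) % max_size
data[index] = new_values
print(index)  # [3 4 0] -- the write wraps to the start of the buffer
print(data)   # [12  0  0 10 11]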
def preprocess(x):
    """Cast to float, normalize, and concatenate images along last axis."""
    x = tree.map_structure(
        lambda image: tf.image.convert_image_dtype(image, tf.float32), x)
    x = tree.flatten(x)
    x = tf.concat(x, axis=-1)
    # `convert_image_dtype` already scaled to [0, 1]; shift to [-1, 1].
    x = (x - 0.5) * 2.0
    return x
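# A hypothetical usage sketch for `preprocess` above: two uint8 camera
# images are cast to float, concatenated along channels, and rescaled to
# [-1, 1]. The observation keys and shapes are made up for illustration.
import tensorflow as tf

observation = {
    'camera_left': tf.zeros([4, 48, 48, 3], dtype=tf.uint8),
    'camera_right': tf.zeros([4, 48, 48, 3], dtype=tf.uint8),
}
out = preprocess(observation)
print(out.shape)  # (4, 48, 48, 6)
print(out.dtype)  # float32, values in [-1, 1]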
def load_experience(self, experience_path):
    with gzip.open(experience_path, 'rb') as f:
        latest_samples = pickle.load(f)

    num_samples = tree.flatten(latest_samples)[0].shape[0]

    def assert_shape(data):
        assert data.shape[0] == num_samples, data.shape

    # Every field must share the same leading (sample) dimension.
    tree.map_structure(assert_shape, latest_samples)

    self.add_samples(latest_samples)
    self._samples_since_save = 0
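# A hypothetical writer counterpart (not from the source) showing the
# on-disk format `load_experience` above expects: a gzip-compressed
# pickle of a (possibly nested) structure of numpy arrays that share
# their leading sample dimension. Field names and the path are made up.
import gzip
import pickle

import numpy as np

samples = {
    'observations': np.zeros((128, 7), dtype=np.float32),
    'actions': np.zeros((128, 2), dtype=np.float32),
}
with gzip.open('/tmp/experience.pkl.gz', 'wb') as f:
    pickle.dump(samples, f)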
def add_path(self, path):
    path = path.copy()
    path_length = tree.flatten(path)[0].shape[0]
    path.update({
        # Steps since the beginning of the episode: [0, 1, ..., T-1].
        'episode_index_forwards': np.arange(
            path_length,
            dtype=self.fields['episode_index_forwards'].dtype
        )[..., np.newaxis],
        # Steps until the end of the episode: [T-1, ..., 1, 0].
        'episode_index_backwards': np.arange(
            path_length,
            dtype=self.fields['episode_index_backwards'].dtype
        )[::-1, np.newaxis],
    })

    return self.add_samples(path)
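# A quick illustration of the per-step indices `add_path` above attaches
# to a path of length 4; the dtype here is an illustrative assumption,
# and the resulting shapes are (path_length, 1).
import numpy as np

path_length = 4
forwards = np.arange(path_length, dtype=np.uint64)[..., np.newaxis]
backwards = np.arange(path_length, dtype=np.uint64)[::-1, np.newaxis]
print(forwards.ravel())   # [0 1 2 3] -- steps since the episode start
print(backwards.ravel())  # [3 2 1 0] -- steps until the episode end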
def actions(self, observations):
    if 0 < self._smoothing_alpha:
        raise NotImplementedError(
            "TODO(hartikainen): Smoothing alpha temporarily dropped on"
            " tf2 migration. Should add it back. See:"
            " https://github.com/rail-berkeley/softlearning/blob/46374df0294b9b5f6dbe65b9471ec491a82b6944/softlearning/policies/base_policy.py#L80")

    observations = self._filter_observations(observations)
    batch_shape = tf.shape(tree.flatten(observations)[0])[:-1]
    actions = self.action_distribution.sample(
        batch_shape,
        bijector_kwargs={
            self.flow_model.name: {'observations': observations}
        })

    return actions
def cast_and_concat(x):
    """Cast all leaves to float32 and concatenate along the last axis."""
    x = tree.map_structure(lambda element: tf.cast(element, tf.float32), x)
    x = tree.flatten(x)
    x = tf.concat(x, axis=-1)
    return x
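# A hypothetical usage sketch for `cast_and_concat` above: mixed-dtype
# leaves of a nested structure become one float32 feature tensor. The
# keys and shapes are illustrative.
import tensorflow as tf

inputs = {
    'position': tf.zeros([8, 3], dtype=tf.float64),
    'velocity': tf.zeros([8, 3], dtype=tf.float32),
}
features = cast_and_concat(inputs)
print(features.shape)  # (8, 6)
print(features.dtype)  # float32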
def flatten_input_structure(inputs):
    """Flatten a nested input structure into a flat list of leaves."""
    inputs_flat = tree.flatten(inputs)
    return inputs_flat
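# Worth noting when relying on `flatten_input_structure` above:
# `tree.flatten` orders dict leaves by sorted key, not insertion order,
# so the flat list is deterministic across runs. Values are illustrative.
import tree

inputs = {'b': 2, 'a': 1, 'c': {'d': 3}}
print(tree.flatten(inputs))  # [1, 2, 3] -- sorted by key: 'a', 'b', 'c'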