Example #1
    def actions_and_log_probs(self, observations):
        """Compute actions and log probabilities together.

        We need this function to avoid numerical issues arising from the
        squashing bijector (`tfp.bijectors.Tanh`). Ideally these would be
        avoided by relying on the bijector's cache and computing actions and
        log probs separately, but that is currently not possible because
        bijector caching is broken in graph mode (i.e. within `tf.function`).
        This method can be removed once the caching works. For more, see:
        https://github.com/tensorflow/probability/issues/840
        """
        observations = self._filter_observations(observations)

        first_observation = tree.flatten(observations)[0]
        first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
        batch_shape = tf.shape(first_observation)[:-first_input_rank]

        shifts, scales = self.shift_and_scale_model(observations)
        actions = self.action_distribution.sample(
            batch_shape,
            bijector_kwargs={'scale': {'scale': scales},
                             'shift': {'shift': shifts}})
        log_probs = self.action_distribution.log_prob(
            actions,
            bijector_kwargs={'scale': {'scale': scales},
                             'shift': {'shift': shifts}}
        )[..., tf.newaxis]

        return actions, log_probs
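The workaround matters because recovering log probabilities for already-squashed samples requires inverting `tanh`, which is numerically unstable near ±1. Below is a minimal sketch of the kind of tanh-squashed distribution `action_distribution` presumably wraps; the base distribution and shapes are illustrative, not the original model:

import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative stand-in for the policy's squashed action distribution.
base = tfp.distributions.MultivariateNormalDiag(
    loc=tf.zeros(2), scale_diag=tf.ones(2))
squashed = tfp.distributions.TransformedDistribution(
    base, tfp.bijectors.Tanh())

actions = squashed.sample(5)
# A standalone log_prob call must compute atanh(actions), which saturates to
# +/-inf for samples near the boundary. The bijector's cache of pre-images
# avoids the inversion, but is unreliable under `tf.function` (see the issue
# linked above), hence the combined method.
log_probs = squashed.log_prob(actions)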
Example #2
    def actions(self, observations):
        """Sample actions for the given observations."""
        first_observation = tree.flatten(observations)[0]
        first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
        batch_shape = tf.shape(first_observation)[:-first_input_rank]

        actions = self.distribution.sample(batch_shape)

        return actions
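The `batch_shape` arithmetic above strips each input's event rank from the observation's shape, so one action is sampled per leading batch (and possibly time) index. A small illustration with assumed shapes:

import tensorflow as tf

observation = tf.zeros((32, 10, 17))   # (batch, time, obs_dim), illustrative
first_input_rank = 1                   # rank of the assumed (17,) input shape
batch_shape = tf.shape(observation)[:-first_input_rank]
print(batch_shape.numpy())             # [32 10]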
Example #3
    def actions(self, observations):
        """Compute actions for given observations."""
        observations = self._filter_observations(observations)

        first_observation = tree.flatten(observations)[0]
        first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
        batch_shape = tf.shape(first_observation)[:-first_input_rank]

        shifts, scales = self.shift_and_scale_model(observations)
        actions = self.action_distribution.sample(
            batch_shape,
            bijector_kwargs={'scale': {'scale': scales},
                             'shift': {'shift': shifts}})

        return actions
Example #4
    def add_samples(self, samples):
        """Write a batch of samples into the pool's ring buffer."""
        num_samples = tree.flatten(samples)[0].shape[0]

        assert (('episode_index_forwards' in samples.keys())
                is ('episode_index_backwards' in samples.keys()))
        if 'episode_index_forwards' not in samples.keys():
            samples['episode_index_forwards'] = np.full(
                (num_samples, *self.fields['episode_index_forwards'].shape),
                self.fields['episode_index_forwards'].default_value,
                dtype=self.fields['episode_index_forwards'].dtype)
            samples['episode_index_backwards'] = np.full(
                (num_samples, *self.fields['episode_index_backwards'].shape),
                self.fields['episode_index_backwards'].default_value,
                dtype=self.fields['episode_index_backwards'].dtype)

        index = np.arange(
            self._pointer, self._pointer + num_samples) % self._max_size

        def add_sample(path, data, new_values, field):
            # `path` and `field` are unused here; the signature matches what
            # `tree.map_structure_with_path` passes for the three structures.
            assert new_values.shape[0] == num_samples, (
                new_values.shape, num_samples)
            data[index] = new_values

        tree.map_structure_with_path(
            add_sample, self.data, samples, self.fields)

        self._advance(num_samples)
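The modular index is what makes the pool a ring buffer: once the pointer passes `_max_size`, writes wrap around and overwrite the oldest entries. With illustrative numbers:

import numpy as np

max_size, pointer, num_samples = 10, 8, 5   # assumed pool state
index = np.arange(pointer, pointer + num_samples) % max_size
print(index)  # [8 9 0 1 2] -- the last three writes wrap to the start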
Example #5
def preprocess(x):
    """Cast images to float, concatenate them along the last axis, and
    rescale the result from [0, 1] to [-1, 1]."""
    x = tree.map_structure(
        lambda image: tf.image.convert_image_dtype(image, tf.float32), x)
    x = tree.flatten(x)
    x = tf.concat(x, axis=-1)
    # The images are already float32 in [0, 1] at this point, so only the
    # rescaling to [-1, 1] remains.
    x = (x - 0.5) * 2.0
    return x
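A hypothetical call, assuming two uint8 camera images with matching spatial dimensions (and the `tree`/`tensorflow` imports the snippet relies on); `tf.image.convert_image_dtype` scales uint8 inputs to [0, 1] before the shift to [-1, 1]:

import numpy as np

x = {'rgb': np.zeros((4, 32, 32, 3), np.uint8),       # hypothetical inputs
     'depth': np.full((4, 32, 32, 1), 255, np.uint8)}
out = preprocess(x)
print(out.shape)   # (4, 32, 32, 4) -- channels concatenated on the last axis
print(float(tf.reduce_min(out)), float(tf.reduce_max(out)))  # -1.0 1.0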
Example #6
    def load_experience(self, experience_path):
        """Load pickled samples from disk and add them to the pool."""
        with gzip.open(experience_path, 'rb') as f:
            latest_samples = pickle.load(f)

        num_samples = tree.flatten(latest_samples)[0].shape[0]

        def assert_shape(data):
            assert data.shape[0] == num_samples, data.shape

        tree.map_structure(assert_shape, latest_samples)

        self.add_samples(latest_samples)
        self._samples_since_save = 0
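For reference, the on-disk format this method assumes is a gzip-compressed pickle of a (possibly nested) structure of equal-length arrays. A sketch of writing a compatible file; the path and field names are illustrative:

import gzip
import pickle

import numpy as np

samples = {'observations': np.zeros((100, 3), np.float32),
           'rewards': np.zeros((100, 1), np.float32)}
with gzip.open('/tmp/experience.pkl', 'wb') as f:
    pickle.dump(samples, f)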
Example #7
    def add_path(self, path):
        """Add a full episode, attaching per-step episode indices."""
        path = path.copy()
        path_length = tree.flatten(path)[0].shape[0]
        path.update({
            'episode_index_forwards': np.arange(
                path_length,
                dtype=self.fields['episode_index_forwards'].dtype
            )[..., np.newaxis],
            'episode_index_backwards': np.arange(
                path_length,
                dtype=self.fields['episode_index_backwards'].dtype
            )[::-1, np.newaxis],
        })

        return self.add_samples(path)
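The two index fields count steps from the episode's start and from its end, respectively. For an assumed path of length 4:

import numpy as np

path_length = 4  # illustrative
forwards = np.arange(path_length)[..., np.newaxis]    # [[0], [1], [2], [3]]
backwards = np.arange(path_length)[::-1, np.newaxis]  # [[3], [2], [1], [0]]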
Example #8
    def actions(self, observations):
        """Sample actions by conditioning the flow on the observations."""
        if 0 < self._smoothing_alpha:
            raise NotImplementedError(
                "TODO(hartikainen): Smoothing alpha temporarily dropped on tf2"
                " migration. Should add it back. See:"
                " https://github.com/rail-berkeley/softlearning/blob/46374df0294b9b5f6dbe65b9471ec491a82b6944/softlearning/policies/base_policy.py#L80")

        observations = self._filter_observations(observations)

        batch_shape = tf.shape(tree.flatten(observations)[0])[:-1]
        actions = self.action_distribution.sample(
            batch_shape, bijector_kwargs={
                self.flow_model.name: {'observations': observations}
            })

        return actions
Example #9
def cast_and_concat(x):
    """Cast all leaves to float32 and concatenate along the last axis."""
    x = tree.map_structure(lambda element: tf.cast(element, tf.float32), x)
    x = tree.flatten(x)
    x = tf.concat(x, axis=-1)
    return x
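A hypothetical call, showing mixed-dtype nested inputs collapsing into a single float32 tensor; the field names and shapes are made up:

import numpy as np

inputs = {'position': np.zeros((8, 3), np.float64),
          'goal_reached': np.ones((8, 1), np.bool_)}
out = cast_and_concat(inputs)
print(out.shape, out.dtype)  # (8, 4) <dtype: 'float32'>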
Example #10
def flatten_input_structure(inputs):
    """Flatten a nested input structure into a flat list of leaves."""
    inputs_flat = tree.flatten(inputs)
    return inputs_flat
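Note that `tree.flatten` traverses dicts in sorted-key order, so the resulting flat list has a deterministic layout:

inputs = {'b': 2, 'a': 1, 'c': {'d': 3}}
print(flatten_input_structure(inputs))  # [1, 2, 3]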