def _single_batch_a_func(self,
                         features,
                         scope,
                         mode,
                         context_fn=None,
                         reuse=tf.AUTO_REUSE):
  """A state -> action regression function that expects a single batch dim.

  Builds the vision trunk, optionally mixes in the gripper pose and a
  context transform, then emits either a mixture-density head (when
  `self._num_mixture_components > 1`) or a plain regression head.

  Args:
    features: A structure with an `image` tensor and, when
      `self._use_gripper_input` is set, a `gripper_pose` tensor.
      # NOTE(review): exact shapes are not visible from here — presumably
      # image is (batch, H, W, C) and gripper_pose is (batch, pose_dim).
    scope: Name of the variable scope wrapping the whole subgraph.
    mode: Estimator-style mode key; `mode == TRAIN` toggles `is_training`
      in the vision layers.
    context_fn: Optional callable applied to the visual feature points
      before the action head (e.g. to inject task conditioning).
    reuse: Variable-reuse setting for the scope.

  Returns:
    A dict with keys 'action', 'image', 'feature_points', 'softmax', and
    additionally 'dist_params' when the mixture-density head is used.
  """
  gripper_pose = features.gripper_pose if self._use_gripper_input else None
  with tf.variable_scope(scope, reuse=reuse, use_resource=True):
    with tf.variable_scope('state_features', reuse=reuse, use_resource=True):
      feature_points, end_points = vision_layers.BuildImagesToFeaturesModel(
          features.image,
          is_training=(mode == TRAIN),
          normalizer_fn=tf.contrib.layers.layer_norm)
    if context_fn:
      feature_points = context_fn(feature_points)
    # BUG FIX: the original unconditionally ran
    # tf.concat([feature_points, gripper_pose], -1), which raises when
    # gripper input is disabled (gripper_pose is None). Only concatenate
    # when a gripper pose is actually present.
    if gripper_pose is not None:
      fc_input = tf.concat([feature_points, gripper_pose], -1)
    else:
      fc_input = feature_points
    outputs = {}
    if self._num_mixture_components > 1:
      # Mixture Density Network head: predict mixture params, then either
      # sample from the mixture or take an approximate mode as the action.
      dist_params = mdn.predict_mdn_params(
          fc_input,
          self._num_mixture_components,
          self._action_size,
          condition_sigmas=self._condition_mixture_stddev)
      gm = mdn.get_mixture_distribution(
          dist_params, self._num_mixture_components, self._action_size,
          self._output_mean if self._normalize_outputs else None)
      if self._output_mixture_sample:
        # Output a mixture sample as action.
        action = gm.sample()
      else:
        action = mdn.gaussian_mixture_approximate_mode(gm)
      outputs['dist_params'] = dist_params
    else:
      # Plain regression head; un-normalize with the stored output stats.
      action, _ = vision_layers.BuildImageFeaturesToPoseModel(
          fc_input, num_outputs=self._action_size)
      action = self._output_mean + self._output_stddev * action
    outputs.update({
        'action': action,
        'image': features.image,
        'feature_points': feature_points,
        'softmax': end_points['softmax']
    })
    return outputs
def test_predict_mdn_params(self, condition_sigmas):
  """Checks MDN param shape and whether stddevs are input-conditioned."""
  action_dim = 10
  mixture_size = 5
  net_input = tf.random.normal((2, 16))
  with tf.variable_scope('test_scope'):
    params = mdn.predict_mdn_params(
        net_input, mixture_size, action_dim, condition_sigmas=condition_sigmas)
  # Per component: one mixing weight plus a mean and a stddev per action dim.
  self.assertEqual(params.shape.as_list(),
                   [2, mixture_size * (1 + 2 * action_dim)])
  mixture = mdn.get_mixture_distribution(params, mixture_size, action_dim)
  sigma_tensor = mixture.components_distribution.stddev()
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sigma_vals = sess.run(sigma_tensor)
    if condition_sigmas:
      # Standard deviations should vary with input.
      self.assertNotAllClose(sigma_vals[0], sigma_vals[1])
    else:
      # Standard deviations should *not* vary with input.
      self.assertAllClose(sigma_vals[0], sigma_vals[1])