Example #1
    def inference_network_fn(self,
                             features,
                             labels,
                             mode,
                             config=None,
                             params=None):
        """Predicts poses from conditioning and inference image/gripper sequences."""
        is_training = mode == TRAIN
        condition_sequence_length = features.condition.features.image.shape[
            2].value
        inference_sequence_length = features.inference.features.image.shape[
            2].value

        # Conditioning only depends on video, does not have access to actual
        # trajectory.
        if not self._condition_gripper_pose:
            features.condition.features.gripper_pose = tf.zeros_like(
                features.condition.features.gripper_pose)

        def concat_across_time(key):
            # Assuming only 1 condition, 1 inference batch for now.
            return tf.concat([
                features.condition.features[key][:, 0],
                features.inference.features[key][:, 0]
            ], axis=1)

        images = concat_across_time('image')
        aux_input = concat_across_time('gripper_pose')

        outputs = {}

        if self._num_mixture_components > 1:
            num_mus = self._action_size * self._num_mixture_components
            num_outputs = self._num_mixture_components + 2 * num_mus
        else:
            num_outputs = self._action_size
        poses, end_points = self._sequence_model_fn(
            images,
            aux_input,
            is_training=is_training,
            output_size=num_outputs,
            condition_sequence_length=condition_sequence_length,
            inference_sequence_length=inference_sequence_length)
        if self._num_mixture_components > 1:
            dist_params = poses[:, condition_sequence_length:]
            self._gm = mdn.get_mixture_distribution(
                dist_params, self._num_mixture_components, self._action_size)
            if self._greedy_action:
                inference_poses = self._gm.sample()
            else:
                inference_poses = mdn.gaussian_mixture_approximate_mode(
                    self._gm)
        else:
            # Only the tail end of the sequence is used for inference.
            inference_poses = poses[:, condition_sequence_length:]
        outputs['inference_output'] = tf.expand_dims(inference_poses, 1)
        outputs.update(end_points)
        return outputs
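For reference, the head size computed above works out to K * (1 + 2 * D): K mixture logits plus K * D component means and K * D standard deviations. A quick standalone check of that arithmetic (the concrete numbers are illustrative and match the test in Example #6 below):

num_mixture_components = 5                                # K
action_size = 10                                          # D
num_mus = action_size * num_mixture_components            # K * D component means
num_outputs = num_mixture_components + 2 * num_mus        # logits + means + stddevs
assert num_outputs == num_mixture_components * (1 + 2 * action_size)  # 105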
Example #2
 def loss_fn(self, labels, inference_outputs, mode, params=None):
   """This implements outer loss and configurable inner losses."""
   if params and params.get('is_outer_loss', False):
     # Hook for outer-loss-specific handling; intentionally a no-op here.
     pass
   if self._num_mixture_components > 1:
     gm = mdn.get_mixture_distribution(
         inference_outputs['dist_params'], self._num_mixture_components,
         self._action_size,
         self._output_mean if self._normalize_outputs else None)
     return -tf.reduce_mean(gm.log_prob(labels.action))
   else:
     return self._outer_loss_multiplier * tf.losses.mean_squared_error(
         labels=labels.action, predictions=inference_outputs['action'])
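The two branches above optimize matching objectives: for a single Gaussian with fixed isotropic variance, the negative log-likelihood -log N(a | mu, sigma^2 I) = ||a - mu||^2 / (2 sigma^2) + const, so minimizing it is equivalent, up to scale and an additive constant, to the mean squared error used in the non-mixture branch.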
Example #3
    def _single_batch_a_func(self,
                             features,
                             scope,
                             mode,
                             context_fn=None,
                             reuse=tf.AUTO_REUSE):
        """A state -> action regression function that expects a single batch dim."""
        gripper_pose = features.gripper_pose if self._use_gripper_input else None
        with tf.variable_scope(scope, reuse=reuse, use_resource=True):
            with tf.variable_scope('state_features',
                                   reuse=reuse,
                                   use_resource=True):
                feature_points, end_points = vision_layers.BuildImagesToFeaturesModel(
                    features.image,
                    is_training=(mode == TRAIN),
                    normalizer_fn=tf.contrib.layers.layer_norm)

            if context_fn:
                feature_points = context_fn(feature_points)

            if gripper_pose is not None:
                fc_input = tf.concat([feature_points, gripper_pose], -1)
            else:
                # No gripper input configured; use the image features alone.
                fc_input = feature_points
            outputs = {}
            if self._num_mixture_components > 1:
                dist_params = mdn.predict_mdn_params(
                    fc_input,
                    self._num_mixture_components,
                    self._action_size,
                    condition_sigmas=self._condition_mixture_stddev)
                gm = mdn.get_mixture_distribution(
                    dist_params, self._num_mixture_components,
                    self._action_size,
                    self._output_mean if self._normalize_outputs else None)
                if self._output_mixture_sample:
                    # Output a mixture sample as action.
                    action = gm.sample()
                else:
                    action = mdn.gaussian_mixture_approximate_mode(gm)
                outputs['dist_params'] = dist_params
            else:
                action, _ = vision_layers.BuildImageFeaturesToPoseModel(
                    fc_input, num_outputs=self._action_size)
                action = self._output_mean + self._output_stddev * action
        outputs.update({
            'action': action,
            'image': features.image,
            'feature_points': feature_points,
            'softmax': end_points['softmax']
        })
        return outputs
Example #4
    def inference_network_fn(self,
                             features,
                             labels,
                             mode,
                             config=None,
                             params=None):
        """See base class."""
        condition_embedding = self._embed_episode(features.condition)
        gripper_pose = features.inference.features.gripper_pose
        fc_embedding = tf.tile(condition_embedding[:, :, None, :],
                               [1, 1, self._episode_length, 1])
        with tf.variable_scope('state_features',
                               reuse=tf.AUTO_REUSE,
                               use_resource=True):
            state_features, _ = meta_tfdata.multi_batch_apply(
                vision_layers.BuildImagesToFeaturesModel, 3,
                features.inference.features.image)
        if self._ignore_embedding:
            fc_inputs = tf.concat([state_features, gripper_pose], -1)
        else:
            fc_inputs = tf.concat([state_features, gripper_pose, fc_embedding],
                                  -1)

        outputs = {}
        with tf.variable_scope('a_func',
                               reuse=tf.AUTO_REUSE,
                               use_resource=True):
            if self._num_mixture_components > 1:
                dist_params = meta_tfdata.multi_batch_apply(
                    mdn.predict_mdn_params, 3, fc_inputs,
                    self._num_mixture_components, self._action_size, False)
                outputs['dist_params'] = dist_params
                gm = mdn.get_mixture_distribution(dist_params,
                                                  self._num_mixture_components,
                                                  self._action_size)
                action = mdn.gaussian_mixture_approximate_mode(gm)
            else:
                action, _ = meta_tfdata.multi_batch_apply(
                    vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
                    self._action_size)
        outputs['inference_output'] = action
        return outputs
Example #5
 def model_train_fn(self,
                    features,
                    labels,
                    inference_outputs,
                    mode,
                    config=None,
                    params=None):
     """Returns weighted sum of losses and individual losses. See base class."""
     if self._num_mixture_components > 1:
         gm = mdn.get_mixture_distribution(inference_outputs['dist_params'],
                                           self._num_mixture_components,
                                           self._action_size)
         bc_loss = -tf.reduce_mean(gm.log_prob(labels.action))
     else:
         bc_loss = tf.losses.mean_squared_error(
             labels=labels.action,
             predictions=inference_outputs['inference_output'])
     if mode == TRAIN and self.use_summaries(params):
         tf.summary.scalar('bc_loss', bc_loss)
     return bc_loss
Example #6
  def test_predict_mdn_params(self, condition_sigmas):
    sample_size = 10
    num_alphas = 5
    inputs = tf.random.normal((2, 16))
    with tf.variable_scope('test_scope'):
      dist_params = mdn.predict_mdn_params(
          inputs, num_alphas, sample_size, condition_sigmas=condition_sigmas)
    expected_num_params = num_alphas * (1 + 2 * sample_size)
    self.assertEqual(dist_params.shape.as_list(), [2, expected_num_params])

    gm = mdn.get_mixture_distribution(dist_params, num_alphas, sample_size)
    stddev = gm.components_distribution.stddev()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      stddev_np = sess.run(stddev)
      if condition_sigmas:
        # Standard deviations should vary with input.
        self.assertNotAllClose(stddev_np[0], stddev_np[1])
      else:
        # Standard deviations should *not* vary with input.
        self.assertAllClose(stddev_np[0], stddev_np[1])
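Taken together, the examples rely on a short call chain from the mdn module: predict_mdn_params builds the parameter head, get_mixture_distribution wraps the raw parameters in a mixture distribution, and gaussian_mixture_approximate_mode (or sample) turns it into an action. A minimal standalone sketch of that chain, assuming the same mdn module is importable and TF1 graph mode as in the examples above (tensor shapes and variable names here are illustrative):

import tensorflow as tf  # TF 1.x

num_components = 5      # mixture components (alphas)
action_size = 10        # per-component event size
state_features = tf.random.normal((8, 64))           # e.g. flattened image/state features
target_action = tf.random.normal((8, action_size))   # placeholder regression target

with tf.variable_scope('mdn_head'):
  # Emits num_components * (1 + 2 * action_size) raw parameters.
  dist_params = mdn.predict_mdn_params(
      state_features, num_components, action_size, condition_sigmas=True)
gm = mdn.get_mixture_distribution(dist_params, num_components, action_size)
sampled_action = gm.sample()                               # stochastic action
greedy_action = mdn.gaussian_mixture_approximate_mode(gm)  # deterministic action
nll_loss = -tf.reduce_mean(gm.log_prob(target_action))     # as in Examples #2 and #5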
Example #7
  def test_get_mixture_distribution(self):
    sample_size = 10
    num_alphas = 5
    batch_shape = (4, 2)
    alphas = tf.random.normal(batch_shape + (num_alphas,))
    mus = tf.random.normal(batch_shape + (sample_size * num_alphas,))
    sigmas = tf.random.normal(batch_shape + (sample_size * num_alphas,))
    params = tf.concat([alphas, mus, sigmas], -1)
    output_mean_np = np.random.normal(size=(sample_size,))
    gm = mdn.get_mixture_distribution(
        params, num_alphas, sample_size, output_mean=output_mean_np)
    self.assertEqual(gm.batch_shape, batch_shape)
    self.assertEqual(gm.event_shape, sample_size)

    # Check that the component means were translated by output_mean_np.
    component_means = gm.components_distribution.mean()
    with self.test_session() as sess:
      # Note: must get values from the same session run, since params will be
      # randomized across separate session runs.
      component_means_np, mus_np = sess.run([component_means, mus])
      mus_np = np.reshape(mus_np, component_means_np.shape)
      self.assertAllClose(component_means_np, mus_np + output_mean_np)
Example #8
    def inference_network_fn(self,
                             features,
                             labels,
                             mode,
                             config=None,
                             params=None):
        """See base class."""
        inf_full_state_pose = features.inference.features.full_state_pose
        con_full_state_pose = features.condition.features.full_state_pose
        # Map success labels [0, 1] -> [-1, 1]
        con_success = 2 * features.condition.labels.success - 1
        if self._retrial and con_full_state_pose.shape[1] != 2:
            raise ValueError('Unexpected shape {}.'.format(
                con_full_state_pose.shape))
        if self._embed_type == 'temporal':
            fc_embedding = meta_tfdata.multi_batch_apply(
                tec.reduce_temporal_embeddings, 2,
                con_full_state_pose[:, 0:1, :, :], self._fc_embed_size,
                'demo_embedding')[:, :, None, :]
        elif self._embed_type == 'mean':
            fc_embedding = con_full_state_pose[:, 0:1, -1:, :]
        else:
            raise ValueError('Invalid embed_type: {}.'.format(
                self._embed_type))
        # Tile the demo embedding across all 40 timesteps of the episode.
        fc_embedding = tf.tile(fc_embedding, [1, 1, 40, 1])
        if self._retrial:
            con_input = tf.concat([
                con_full_state_pose[:, 1:2, :, :], con_success[:, 1:2, :, :],
                fc_embedding
            ], -1)
            if self._embed_type == 'mean':
                trial_embedding = meta_tfdata.multi_batch_apply(
                    tec.embed_fullstate, 3, con_input, self._fc_embed_size,
                    'trial_embedding')
                trial_embedding = tf.reduce_mean(trial_embedding, -2)
            else:
                trial_embedding = meta_tfdata.multi_batch_apply(
                    tec.reduce_temporal_embeddings, 2, con_input,
                    self._fc_embed_size, 'trial_embedding')
            trial_embedding = tf.tile(trial_embedding[:, :, None, :],
                                      [1, 1, 40, 1])
            fc_embedding = tf.concat([fc_embedding, trial_embedding], -1)
        if self._ignore_embedding:
            fc_inputs = inf_full_state_pose
        else:
            fc_inputs = [inf_full_state_pose, fc_embedding]
            if self._retrial:
                fc_inputs.append(con_success[:, 1:2, :, :])
            fc_inputs = tf.concat(fc_inputs, -1)
        outputs = {}
        with tf.variable_scope('a_func',
                               reuse=tf.AUTO_REUSE,
                               use_resource=True):
            if self._num_mixture_components > 1:
                fc_inputs, _ = meta_tfdata.multi_batch_apply(
                    vision_layers.BuildImageFeaturesToPoseModel,
                    3,
                    fc_inputs,
                    num_outputs=None)
                dist_params = meta_tfdata.multi_batch_apply(
                    mdn.predict_mdn_params, 3, fc_inputs,
                    self._num_mixture_components, self._action_size, False)
                outputs['dist_params'] = dist_params
                gm = mdn.get_mixture_distribution(dist_params,
                                                  self._num_mixture_components,
                                                  self._action_size)
                action = mdn.gaussian_mixture_approximate_mode(gm)
            else:
                action, _ = meta_tfdata.multi_batch_apply(
                    vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
                    self._action_size)

        outputs.update({
            'inference_output': action,
        })

        return outputs