def inference_network_fn(self,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None):
  """See base class."""
  is_training = mode == TRAIN
  condition_sequence_length = features.condition.features.image.shape[2].value
  inference_sequence_length = features.inference.features.image.shape[2].value
  # Conditioning only depends on video; it does not have access to the actual
  # trajectory.
  if not self._condition_gripper_pose:
    features.condition.features.gripper_pose = tf.zeros_like(
        features.condition.features.gripper_pose)

  def concat_across_time(key):
    # Assuming only 1 condition and 1 inference batch for now.
    return tf.concat([
        features.condition.features[key][:, 0],
        features.inference.features[key][:, 0]
    ], axis=1)

  images = concat_across_time('image')
  aux_input = concat_across_time('gripper_pose')
  outputs = {}
  if self._num_mixture_components > 1:
    # Mixture head: K mixture logits plus K*D means and K*D stddevs.
    num_mus = self._action_size * self._num_mixture_components
    num_outputs = self._num_mixture_components + 2 * num_mus
  else:
    num_outputs = self._action_size
  poses, end_points = self._sequence_model_fn(
      images,
      aux_input,
      is_training=is_training,
      output_size=num_outputs,
      condition_sequence_length=condition_sequence_length,
      inference_sequence_length=inference_sequence_length)
  if self._num_mixture_components > 1:
    dist_params = poses[:, condition_sequence_length:]
    self._gm = mdn.get_mixture_distribution(
        dist_params, self._num_mixture_components, self._action_size)
    if self._greedy_action:
      # Greedy inference takes the (approximate) mixture mode; otherwise
      # sample from the mixture.
      inference_poses = mdn.gaussian_mixture_approximate_mode(self._gm)
    else:
      inference_poses = self._gm.sample()
  else:
    # Only the tail end of the sequence is used for inference.
    inference_poses = poses[:, condition_sequence_length:]
  outputs['inference_output'] = tf.expand_dims(inference_poses, 1)
  outputs.update(end_points)
  return outputs
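# Bookkeeping behind `num_outputs` above: a mixture of K Gaussians over a
# D-dimensional action needs K mixture logits plus K*D means and K*D standard
# deviations. The helper below is a hypothetical illustration of that
# arithmetic, not part of the original module.
def mdn_output_size(num_mixture_components, action_size):
  """Flat output size needed to parameterize an MDN head (sketch)."""
  num_mus = num_mixture_components * action_size
  # K mixture logits + K*D means + K*D stddevs.
  return num_mixture_components + 2 * num_mus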
def _single_batch_a_func(self,
                         features,
                         scope,
                         mode,
                         context_fn=None,
                         reuse=tf.AUTO_REUSE):
  """A state -> action regression function that expects a single batch dim."""
  gripper_pose = features.gripper_pose if self._use_gripper_input else None
  with tf.variable_scope(scope, reuse=reuse, use_resource=True):
    with tf.variable_scope('state_features', reuse=reuse, use_resource=True):
      feature_points, end_points = vision_layers.BuildImagesToFeaturesModel(
          features.image,
          is_training=(mode == TRAIN),
          normalizer_fn=tf.contrib.layers.layer_norm)
    if context_fn:
      feature_points = context_fn(feature_points)
    if gripper_pose is not None:
      fc_input = tf.concat([feature_points, gripper_pose], -1)
    else:
      # Guard the no-gripper-input case; tf.concat with None would fail.
      fc_input = feature_points
    outputs = {}
    if self._num_mixture_components > 1:
      dist_params = mdn.predict_mdn_params(
          fc_input,
          self._num_mixture_components,
          self._action_size,
          condition_sigmas=self._condition_mixture_stddev)
      gm = mdn.get_mixture_distribution(
          dist_params, self._num_mixture_components, self._action_size,
          self._output_mean if self._normalize_outputs else None)
      if self._output_mixture_sample:
        # Output a mixture sample as the action.
        action = gm.sample()
      else:
        action = mdn.gaussian_mixture_approximate_mode(gm)
      outputs['dist_params'] = dist_params
    else:
      action, _ = vision_layers.BuildImageFeaturesToPoseModel(
          fc_input, num_outputs=self._action_size)
      # Un-normalize the regressed action.
      action = self._output_mean + self._output_stddev * action
    outputs.update({
        'action': action,
        'image': features.image,
        'feature_points': feature_points,
        'softmax': end_points['softmax']
    })
  return outputs
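# For reference, a minimal sketch of what `mdn.predict_mdn_params` and
# `mdn.get_mixture_distribution` are assumed to do: a dense layer emits flat
# mixture parameters, which are unpacked into a tfp MixtureSameFamily
# distribution. The softplus stddev parameterization and the `_sketch` helper
# names are assumptions, not the library's actual implementation.
import tensorflow as tf
import tensorflow_probability as tfp


def predict_mdn_params_sketch(fc_input, num_components, sample_size):
  """Predicts flat MDN parameters from features (sketch)."""
  num_mus = num_components * sample_size
  # K mixture logits, K*D means, K*D pre-activation stddevs.
  return tf.layers.dense(fc_input, num_components + 2 * num_mus)


def get_mixture_distribution_sketch(params, num_components, sample_size):
  """Unpacks flat params into a tfp MixtureSameFamily (sketch)."""
  alphas = params[..., :num_components]
  mus = params[..., num_components:num_components * (1 + sample_size)]
  raw_sigmas = params[..., num_components * (1 + sample_size):]
  component_shape = tf.concat(
      [tf.shape(params)[:-1], [num_components, sample_size]], axis=0)
  mus = tf.reshape(mus, component_shape)
  # Softplus keeps the stddevs positive.
  sigmas = tf.nn.softplus(tf.reshape(raw_sigmas, component_shape))
  return tfp.distributions.MixtureSameFamily(
      mixture_distribution=tfp.distributions.Categorical(logits=alphas),
      components_distribution=tfp.distributions.MultivariateNormalDiag(
          loc=mus, scale_diag=sigmas))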
def test_gaussian_mixture_approximate_mode(self):
  sample_size = 10
  num_alphas = 5
  # Manually set alphas to 1 in the zero-th column and 0 elsewhere, making
  # the first component the most likely.
  alphas = tf.one_hot(2 * [0], num_alphas)
  mus = tf.random.normal((2, num_alphas, sample_size))
  sigmas = tf.ones_like(mus)
  mix_dist = tfp.distributions.Categorical(logits=alphas)
  comp_dist = tfp.distributions.MultivariateNormalDiag(
      loc=mus, scale_diag=sigmas)
  gm = tfp.distributions.MixtureSameFamily(
      mixture_distribution=mix_dist, components_distribution=comp_dist)
  approximate_mode = mdn.gaussian_mixture_approximate_mode(gm)
  with self.test_session() as sess:
    approximate_mode_np, mus_np = sess.run([approximate_mode, mus])
    # The approximate mode should be the mean of the zero-th (most likely)
    # component.
    self.assertAllClose(approximate_mode_np, mus_np[:, 0, :])
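# The test above pins down the assumed contract of
# `mdn.gaussian_mixture_approximate_mode`: rather than computing the true
# mode of the mixture (which has no closed form), it returns the mean of the
# most probable component. A minimal sketch consistent with that contract,
# assuming a single leading batch dimension; the library's implementation may
# differ in details.
def gaussian_mixture_approximate_mode_sketch(gm):
  """Returns the mean of the most likely mixture component (sketch)."""
  most_likely = gm.mixture_distribution.mode()  # Shape [batch].
  mus = gm.components_distribution.mean()  # Shape [batch, K, D].
  batch_indices = tf.range(tf.shape(mus)[0])
  gather_indices = tf.stack([batch_indices, most_likely], axis=-1)
  return tf.gather_nd(mus, gather_indices)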
def inference_network_fn(self,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None):
  """See base class."""
  condition_embedding = self._embed_episode(features.condition)
  gripper_pose = features.inference.features.gripper_pose
  # Broadcast the per-episode condition embedding across every timestep.
  fc_embedding = tf.tile(condition_embedding[:, :, None, :],
                         [1, 1, self._episode_length, 1])
  with tf.variable_scope(
      'state_features', reuse=tf.AUTO_REUSE, use_resource=True):
    state_features, _ = meta_tfdata.multi_batch_apply(
        vision_layers.BuildImagesToFeaturesModel, 3,
        features.inference.features.image)
  if self._ignore_embedding:
    fc_inputs = tf.concat([state_features, gripper_pose], -1)
  else:
    fc_inputs = tf.concat([state_features, gripper_pose, fc_embedding], -1)
  outputs = {}
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    if self._num_mixture_components > 1:
      dist_params = meta_tfdata.multi_batch_apply(
          mdn.predict_mdn_params, 3, fc_inputs,
          self._num_mixture_components, self._action_size, False)
      outputs['dist_params'] = dist_params
      gm = mdn.get_mixture_distribution(
          dist_params, self._num_mixture_components, self._action_size)
      action = mdn.gaussian_mixture_approximate_mode(gm)
    else:
      action, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          self._action_size)
  outputs['inference_output'] = action
  return outputs
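# `meta_tfdata.multi_batch_apply` is used above to run single-batch layers
# over tensors with several leading batch dimensions, e.g.
# [batch, episode, time, ...]. A minimal sketch of the assumed behavior,
# restricted to a single-tensor output; the real helper also handles kwargs
# and structured outputs.
def multi_batch_apply_sketch(fn, num_batch_dims, *args):
  """Merges the first num_batch_dims dims, applies fn, restores them."""
  tensors = [a for a in args if isinstance(a, tf.Tensor)]
  batch_shape = tf.shape(tensors[0])[:num_batch_dims]

  def flatten(a):
    if not isinstance(a, tf.Tensor):
      return a  # Pass scalars, strings, etc. through untouched.
    inner_shape = tf.shape(a)[num_batch_dims:]
    return tf.reshape(a, tf.concat([[-1], inner_shape], axis=0))

  output = fn(*[flatten(a) for a in args])
  # Restore the leading batch dims on the (assumed single-tensor) output.
  return tf.reshape(
      output, tf.concat([batch_shape, tf.shape(output)[1:]], axis=0))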
def inference_network_fn(self,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None):
  """See base class."""
  inf_full_state_pose = features.inference.features.full_state_pose
  con_full_state_pose = features.condition.features.full_state_pose
  # Map success labels [0, 1] -> [-1, 1].
  con_success = 2 * features.condition.labels.success - 1
  if self._retrial and con_full_state_pose.shape[1] != 2:
    raise ValueError('Unexpected shape {}.'.format(con_full_state_pose.shape))
  if self._embed_type == 'temporal':
    fc_embedding = meta_tfdata.multi_batch_apply(
        tec.reduce_temporal_embeddings, 2, con_full_state_pose[:, 0:1, :, :],
        self._fc_embed_size, 'demo_embedding')[:, :, None, :]
  elif self._embed_type == 'mean':
    fc_embedding = con_full_state_pose[:, 0:1, -1:, :]
  else:
    raise ValueError('Invalid embed_type: {}.'.format(self._embed_type))
  # Tile the demo embedding across the time dimension (episodes are assumed
  # to be a fixed 40 steps long).
  fc_embedding = tf.tile(fc_embedding, [1, 1, 40, 1])
  if self._retrial:
    con_input = tf.concat([
        con_full_state_pose[:, 1:2, :, :], con_success[:, 1:2, :, :],
        fc_embedding
    ], -1)
    if self._embed_type == 'mean':
      trial_embedding = meta_tfdata.multi_batch_apply(
          tec.embed_fullstate, 3, con_input, self._fc_embed_size,
          'trial_embedding')
      trial_embedding = tf.reduce_mean(trial_embedding, -2)
    else:
      trial_embedding = meta_tfdata.multi_batch_apply(
          tec.reduce_temporal_embeddings, 2, con_input, self._fc_embed_size,
          'trial_embedding')
    trial_embedding = tf.tile(trial_embedding[:, :, None, :], [1, 1, 40, 1])
    fc_embedding = tf.concat([fc_embedding, trial_embedding], -1)
  if self._ignore_embedding:
    fc_inputs = inf_full_state_pose
  else:
    fc_inputs = [inf_full_state_pose, fc_embedding]
    if self._retrial:
      fc_inputs.append(con_success[:, 1:2, :, :])
    fc_inputs = tf.concat(fc_inputs, -1)
  outputs = {}
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    if self._num_mixture_components > 1:
      fc_inputs, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          num_outputs=None)
      dist_params = meta_tfdata.multi_batch_apply(
          mdn.predict_mdn_params, 3, fc_inputs,
          self._num_mixture_components, self._action_size, False)
      outputs['dist_params'] = dist_params
      gm = mdn.get_mixture_distribution(
          dist_params, self._num_mixture_components, self._action_size)
      action = mdn.gaussian_mixture_approximate_mode(gm)
    else:
      action, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          self._action_size)
  outputs.update({
      'inference_output': action,
  })
  return outputs
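# `tec.reduce_temporal_embeddings` is assumed to collapse a [batch, time,
# feature] sequence into one fixed-size vector per episode. A hypothetical
# sketch (a per-timestep dense layer followed by mean pooling over time); the
# actual TEC reduction may differ, e.g. it could use attention or an RNN.
def reduce_temporal_embeddings_sketch(temporal_features, embed_size, scope):
  """Reduces [batch, time, feature] -> [batch, embed_size] (sketch)."""
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE, use_resource=True):
    embedded = tf.layers.dense(
        temporal_features, embed_size, activation=tf.nn.relu)
    # Mean-pool over the time dimension.
    return tf.reduce_mean(embedded, axis=-2)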