def inference_network_fn(self, features, labels, mode, config=None, params=None):
  """Builds the sequence-conditioned inference graph.

  Args:
    features: Nested structure with `condition` and `inference` branches,
      each providing `image` and `gripper_pose` features.
    labels: Unused; present for the standard model-fn signature.
    mode: TRAIN/EVAL/PREDICT mode key.
    config: Optional run config (unused).
    params: Optional hyperparameters (unused).

  Returns:
    Dict with 'inference_output' (poses predicted for the inference
    sub-sequence, with an extra episode dim of size 1) merged with the
    sequence model's end points. With a mixture head, 'dist_params' is
    also exposed so likelihood losses can be computed.
  """
  is_training = mode == TRAIN
  condition_sequence_length = features.condition.features.image.shape[2].value
  inference_sequence_length = features.inference.features.image.shape[2].value
  # Conditioning only depends on video, does not have access to actual
  # trajectory.
  if not self._condition_gripper_pose:
    features.condition.features.gripper_pose = tf.zeros_like(
        features.condition.features.gripper_pose)

  def concat_across_time(key):
    # Assuming only 1 condition, 1 inference batch for now.
    return tf.concat([
        features.condition.features[key][:, 0],
        features.inference.features[key][:, 0]
    ], axis=1)

  images = concat_across_time('image')
  aux_input = concat_across_time('gripper_pose')
  outputs = {}
  if self._num_mixture_components > 1:
    # Mixture head emits alphas plus (mu, sigma) per component.
    num_mus = self._action_size * self._num_mixture_components
    num_outputs = self._num_mixture_components + 2 * num_mus
  else:
    num_outputs = self._action_size
  poses, end_points = self._sequence_model_fn(
      images,
      aux_input,
      is_training=is_training,
      output_size=num_outputs,
      condition_sequence_length=condition_sequence_length,
      inference_sequence_length=inference_sequence_length)
  if self._num_mixture_components > 1:
    # Only the tail (inference) part of the sequence parameterizes the GMM.
    dist_params = poses[:, condition_sequence_length:]
    self._gm = mdn.get_mixture_distribution(
        dist_params, self._num_mixture_components, self._action_size)
    # BUGFIX: greedy action selection must take the deterministic
    # (approximate) mixture mode; sampling is the stochastic, non-greedy
    # choice. The two branches were previously swapped, which contradicts
    # the sample-vs-mode convention used elsewhere in this file.
    if self._greedy_action:
      inference_poses = mdn.gaussian_mixture_approximate_mode(self._gm)
    else:
      inference_poses = self._gm.sample()
    # Expose the mixture parameters for MDN losses. Assigned before the
    # end_points merge so a sequence model that already emits this key wins.
    outputs['dist_params'] = dist_params
  else:
    # Only the tail end of the sequence is used for inference.
    inference_poses = poses[:, condition_sequence_length:]
  outputs['inference_output'] = tf.expand_dims(inference_poses, 1)
  outputs.update(end_points)
  return outputs
def loss_fn(self, labels, inference_outputs, mode, params=None):
  """This implements outer loss and configurable inner losses.

  Args:
    labels: Structure with `action` ground-truth actions.
    inference_outputs: Dict from the inference network; must contain
      'dist_params' (mixture head) or 'action' (regression head).
    mode: TRAIN/EVAL/PREDICT mode key (unused here).
    params: Optional dict; may carry 'is_outer_loss'. The previous
      `if params.get('is_outer_loss'): pass` branch was a no-op and was
      removed without behavior change.

  Returns:
    Scalar loss tensor.
  """
  if self._num_mixture_components > 1:
    gm = mdn.get_mixture_distribution(
        inference_outputs['dist_params'], self._num_mixture_components,
        self._action_size,
        self._output_mean if self._normalize_outputs else None)
    # Negative log-likelihood of the true action under the mixture.
    # NOTE(review): unlike the MSE branch, this is not scaled by
    # self._outer_loss_multiplier -- confirm whether that is intentional.
    return -tf.reduce_mean(gm.log_prob(labels.action))
  return self._outer_loss_multiplier * tf.losses.mean_squared_error(
      labels=labels.action, predictions=inference_outputs['action'])
def _single_batch_a_func(self, features, scope, mode, context_fn=None, reuse=tf.AUTO_REUSE):
  """A state -> action regression function that expects a single batch dim.

  Args:
    features: Structure with `image` and (optionally used) `gripper_pose`.
    scope: Variable scope name for the model's parameters.
    mode: TRAIN/EVAL/PREDICT mode key; TRAIN enables training-time behavior
      in the vision layers.
    context_fn: Optional callable applied to the visual feature points
      (e.g. to inject conditioning context).
    reuse: Variable reuse mode for the scopes.

  Returns:
    Dict with 'action', 'image', 'feature_points', 'softmax', and -- when a
    mixture head is used -- 'dist_params'.
  """
  gripper_pose = features.gripper_pose if self._use_gripper_input else None
  with tf.variable_scope(scope, reuse=reuse, use_resource=True):
    with tf.variable_scope('state_features', reuse=reuse, use_resource=True):
      feature_points, end_points = vision_layers.BuildImagesToFeaturesModel(
          features.image,
          is_training=(mode == TRAIN),
          normalizer_fn=tf.contrib.layers.layer_norm)
    if context_fn:
      feature_points = context_fn(feature_points)
    # BUGFIX: when gripper input is disabled, gripper_pose is None and
    # tf.concat([feature_points, None], -1) raises. Only concat when the
    # gripper pose is actually present.
    if gripper_pose is not None:
      fc_input = tf.concat([feature_points, gripper_pose], -1)
    else:
      fc_input = feature_points
    outputs = {}
    if self._num_mixture_components > 1:
      dist_params = mdn.predict_mdn_params(
          fc_input,
          self._num_mixture_components,
          self._action_size,
          condition_sigmas=self._condition_mixture_stddev)
      gm = mdn.get_mixture_distribution(
          dist_params, self._num_mixture_components, self._action_size,
          self._output_mean if self._normalize_outputs else None)
      if self._output_mixture_sample:
        # Output a mixture sample as action.
        action = gm.sample()
      else:
        action = mdn.gaussian_mixture_approximate_mode(gm)
      outputs['dist_params'] = dist_params
    else:
      action, _ = vision_layers.BuildImageFeaturesToPoseModel(
          fc_input, num_outputs=self._action_size)
      # Un-normalize the regression output back to action units.
      action = self._output_mean + self._output_stddev * action
  outputs.update({
      'action': action,
      'image': features.image,
      'feature_points': feature_points,
      'softmax': end_points['softmax']
  })
  return outputs
def inference_network_fn(self, features, labels, mode, config=None, params=None):
  """See base class."""
  # Embed the condition episode once and broadcast it over every timestep.
  demo_embedding = self._embed_episode(features.condition)
  tiled_embedding = tf.tile(
      demo_embedding[:, :, None, :], [1, 1, self._episode_length, 1])
  gripper_pose = features.inference.features.gripper_pose
  with tf.variable_scope(
      'state_features', reuse=tf.AUTO_REUSE, use_resource=True):
    state_features, _ = meta_tfdata.multi_batch_apply(
        vision_layers.BuildImagesToFeaturesModel, 3,
        features.inference.features.image)
  # Assemble the per-timestep fully-connected input.
  fc_parts = [state_features, gripper_pose]
  if not self._ignore_embedding:
    fc_parts.append(tiled_embedding)
  fc_inputs = tf.concat(fc_parts, -1)
  outputs = {}
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    if self._num_mixture_components > 1:
      dist_params = meta_tfdata.multi_batch_apply(
          mdn.predict_mdn_params, 3, fc_inputs,
          self._num_mixture_components, self._action_size, False)
      outputs['dist_params'] = dist_params
      mixture = mdn.get_mixture_distribution(
          dist_params, self._num_mixture_components, self._action_size)
      # Deterministic action: approximate mode of the mixture.
      action = mdn.gaussian_mixture_approximate_mode(mixture)
    else:
      action, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          self._action_size)
  outputs['inference_output'] = action
  return outputs
def model_train_fn(self, features, labels, inference_outputs, mode, config=None, params=None):
  """Returns weighted sum of losses and individual losses. See base class."""
  use_mixture = self._num_mixture_components > 1
  if use_mixture:
    # Behavioral cloning via negative log-likelihood under the mixture.
    mixture = mdn.get_mixture_distribution(
        inference_outputs['dist_params'], self._num_mixture_components,
        self._action_size)
    bc_loss = -tf.reduce_mean(mixture.log_prob(labels.action))
  else:
    # Behavioral cloning via mean-squared error on the regressed action.
    bc_loss = tf.losses.mean_squared_error(
        labels=labels.action,
        predictions=inference_outputs['inference_output'])
  if mode == TRAIN and self.use_summaries(params):
    tf.summary.scalar('bc_loss', bc_loss)
  return bc_loss
def test_predict_mdn_params(self, condition_sigmas):
  """Checks param count and whether sigmas are input-conditioned."""
  batch = 2
  action_size = 10
  num_components = 5
  net_inputs = tf.random.normal((batch, 16))
  with tf.variable_scope('test_scope'):
    dist_params = mdn.predict_mdn_params(
        net_inputs, num_components, action_size,
        condition_sigmas=condition_sigmas)
  # One alpha plus a (mu, sigma) vector per component.
  self.assertEqual(dist_params.shape.as_list(),
                   [batch, num_components * (1 + 2 * action_size)])
  mixture = mdn.get_mixture_distribution(dist_params, num_components,
                                         action_size)
  stddev = mixture.components_distribution.stddev()
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    stddev_np = sess.run(stddev)
  if condition_sigmas:
    # Standard deviations should vary with input.
    self.assertNotAllClose(stddev_np[0], stddev_np[1])
  else:
    # Standard deviations should *not* vary with input.
    self.assertAllClose(stddev_np[0], stddev_np[1])
def test_get_mixture_distribution(self):
  """Checks shapes and that output_mean translates the component means."""
  action_size = 10
  num_components = 5
  batch_shape = (4, 2)
  alphas = tf.random.normal(batch_shape + (num_components,))
  mus = tf.random.normal(batch_shape + (action_size * num_components,))
  sigmas = tf.random.normal(batch_shape + (action_size * num_components,))
  params = tf.concat([alphas, mus, sigmas], -1)
  mean_shift = np.random.normal(size=(action_size,))
  mixture = mdn.get_mixture_distribution(
      params, num_components, action_size, output_mean=mean_shift)
  self.assertEqual(mixture.batch_shape, batch_shape)
  self.assertEqual(mixture.event_shape, action_size)
  # Check that the component means were translated by mean_shift.
  component_means = mixture.components_distribution.mean()
  with self.test_session() as sess:
    # Note: must get values from the same session run, since params will be
    # randomized across separate session runs.
    component_means_np, mus_np = sess.run([component_means, mus])
  mus_np = np.reshape(mus_np, component_means_np.shape)
  self.assertAllClose(component_means_np, mus_np + mean_shift)
def inference_network_fn(self, features, labels, mode, config=None, params=None):
  """See base class.

  Builds a full-state (non-visual) policy conditioned on a demo embedding,
  optionally augmented with an embedding of a previous (re)trial episode.
  Returns a dict with 'inference_output' (the action) and, when a mixture
  head is used, 'dist_params'.
  """
  inf_full_state_pose = features.inference.features.full_state_pose
  con_full_state_pose = features.condition.features.full_state_pose
  # Map success labels [0, 1] -> [-1, 1]
  con_success = 2 * features.condition.labels.success - 1
  # Retrial mode expects exactly two condition episodes: the demo (index 0)
  # and the prior trial (index 1).
  if self._retrial and con_full_state_pose.shape[1] != 2:
    raise ValueError('Unexpected shape {}.'.format(
        con_full_state_pose.shape))
  if self._embed_type == 'temporal':
    # Reduce the demo episode's per-timestep states to one embedding, then
    # restore a singleton time axis for tiling below.
    fc_embedding = meta_tfdata.multi_batch_apply(
        tec.reduce_temporal_embeddings, 2, con_full_state_pose[:, 0:1, :, :],
        self._fc_embed_size, 'demo_embedding')[:, :, None, :]
  elif self._embed_type == 'mean':
    # 'mean' embedding: just the final timestep of the demo episode.
    fc_embedding = con_full_state_pose[:, 0:1, -1:, :]
  else:
    raise ValueError('Invalid embed_type: {}.'.format(
        self._embed_type))
  # Broadcast the episode embedding across timesteps.
  # NOTE(review): 40 appears to be the fixed episode length here -- confirm
  # against the data pipeline; consider deriving it from the inputs.
  fc_embedding = tf.tile(fc_embedding, [1, 1, 40, 1])
  if self._retrial:
    # Embed the prior trial (episode 1) jointly with its success signal and
    # the demo embedding.
    con_input = tf.concat([
        con_full_state_pose[:, 1:2, :, :],
        con_success[:, 1:2, :, :],
        fc_embedding
    ], -1)
    if self._embed_type == 'mean':
      # Per-timestep embeddings averaged over time.
      trial_embedding = meta_tfdata.multi_batch_apply(
          tec.embed_fullstate, 3, con_input, self._fc_embed_size,
          'trial_embedding')
      trial_embedding = tf.reduce_mean(trial_embedding, -2)
    else:
      # Temporal reduction over the trial episode.
      trial_embedding = meta_tfdata.multi_batch_apply(
          tec.reduce_temporal_embeddings, 2, con_input,
          self._fc_embed_size, 'trial_embedding')
    # Tile the trial embedding over timesteps and append to the demo
    # embedding.
    trial_embedding = tf.tile(trial_embedding[:, :, None, :],
                              [1, 1, 40, 1])
    fc_embedding = tf.concat([fc_embedding, trial_embedding], -1)
  if self._ignore_embedding:
    # Ablation: policy sees only the raw inference-episode state.
    fc_inputs = inf_full_state_pose
  else:
    fc_inputs = [inf_full_state_pose, fc_embedding]
    if self._retrial:
      fc_inputs.append(con_success[:, 1:2, :, :])
    fc_inputs = tf.concat(fc_inputs, -1)
  outputs = {}
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    if self._num_mixture_components > 1:
      # Trunk network (num_outputs=None yields features, not poses),
      # followed by an MDN head.
      fc_inputs, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          num_outputs=None)
      dist_params = meta_tfdata.multi_batch_apply(
          mdn.predict_mdn_params, 3, fc_inputs,
          self._num_mixture_components, self._action_size, False)
      outputs['dist_params'] = dist_params
      gm = mdn.get_mixture_distribution(dist_params,
                                        self._num_mixture_components,
                                        self._action_size)
      # Deterministic action: approximate mode of the mixture.
      action = mdn.gaussian_mixture_approximate_mode(gm)
    else:
      # Plain regression head.
      action, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          self._action_size)
    outputs.update({
        'inference_output': action,
    })
  return outputs