def _embed_episode(self, episode_data):
  """Produces embeddings from episode data."""
  demo_fp = meta_tfdata.multi_batch_apply(
      tec.embed_condition_images, 3,
      episode_data.features.image[:, 0:1, :, :], 'image_embedding')
  demo_inputs = tf.concat(
      [demo_fp, episode_data.features.gripper_pose[:, 0:1, :, :]], -1)
  embedding = meta_tfdata.multi_batch_apply(
      tec.reduce_temporal_embeddings, 2, demo_inputs, self._fc_embed_size,
      'fc_demo_reduce')
  if self._num_condition_samples_per_task > 1:
    # Map success labels in [0, 1] -> [-1, 1].
    con_success = 2 * episode_data.labels.success - 1
    trial_embedding = meta_tfdata.multi_batch_apply(
        tec.embed_condition_images, 3,
        episode_data.features.image[:, 1:2, :, :], 'image_embedding')
    trial_embedding = tf.concat([
        trial_embedding, episode_data.features.gripper_pose[:, 1:2, :, :],
        con_success[:, 1:2, :, :],
        tf.tile(embedding[:, :, None, :], [1, 1, 40, 1])
    ], -1)
    trial_embedding = meta_tfdata.multi_batch_apply(
        tec.reduce_temporal_embeddings, 2, trial_embedding,
        self._fc_embed_size, 'fc_trial_reduce')
    embedding = tf.concat([embedding, trial_embedding], axis=-1)
  return embedding

def inference_network_fn(self, features, labels, mode, config=None,
                         params=None):
  """See base class."""
  condition_embedding = self._embed_episode(features.condition)
  gripper_pose = features.inference.features.gripper_pose
  fc_embedding = tf.tile(
      condition_embedding[..., :self._fc_embed_size][:, :, None, :],
      [1, 1, self._episode_length, 1])
  with tf.variable_scope(
      'state_features', reuse=tf.AUTO_REUSE, use_resource=True):
    state_features, _ = meta_tfdata.multi_batch_apply(
        vision_layers.BuildImagesToFeaturesModel, 3,
        features.inference.features.image)
  if self._ignore_embedding:
    fc_inputs = tf.concat([state_features, gripper_pose], -1)
  else:
    fc_inputs = tf.concat([state_features, gripper_pose, fc_embedding], -1)
  outputs = {}
  aux_output_dim = 0
  # We only predict the end token for the next step.
  if self._predict_end_weight > 0:
    aux_output_dim = 1
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    action_params, end_token = meta_tfdata.multi_batch_apply(
        vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
        num_outputs=None, aux_output_dim=aux_output_dim)
    action = self._action_decoder(
        params=action_params,
        output_size=self._num_waypoints * self._action_size)
    outputs.update({
        'action': action,  # Used by the policy.
        'condition_embedding': condition_embedding,
    })
    if self._predict_end_weight > 0:
      outputs['end_token_logits'] = end_token
      outputs['end_token'] = tf.nn.sigmoid(end_token)
      outputs['action'] = tf.concat(
          [outputs['action'], outputs['end_token']], -1)
  if mode != PREDICT:
    # During training we embed the inference episodes to compute the triplet
    # loss between condition/inference embeddings.
    inference_embedding = self._embed_episode(features.inference)
    outputs['inference_embedding'] = inference_embedding
  return outputs

def _embed_episode(self, episode_data):
  """Produces embeddings from episode data."""
  image_embedding = meta_tfdata.multi_batch_apply(
      tec.embed_condition_images, 3, episode_data.features.image,
      'image_embedding')
  embedding = meta_tfdata.multi_batch_apply(
      tec.reduce_temporal_embeddings, 2, image_embedding, self._fc_embed_size,
      'fc_reduce')
  return tf.math.l2_normalize(embedding, axis=-1)

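# Illustrative sketch, not part of the original model: inference_network_fn
# above embeds both condition and inference episodes for a triplet loss, but
# the loss itself is not shown in this section. The helper below is a
# hypothetical margin-based triplet loss over the l2-normalized embeddings
# returned by _embed_episode; the name `example_triplet_loss`, the `margin`
# value, and the shift-by-one negative sampling are all assumptions.
def example_triplet_loss(condition_embedding, inference_embedding, margin=0.5):
  """Pulls matching condition/inference embeddings together, pushes others apart.

  Assumes both inputs are [batch, embed_size] and row i of each tensor
  corresponds to the same task.
  """
  # Positive pairs: condition and inference embeddings of the same task.
  pos_dist = tf.reduce_sum(
      tf.square(condition_embedding - inference_embedding), axis=-1)
  # Negative pairs: pair each condition embedding with the inference
  # embedding of a different task by shifting the batch by one.
  shifted = tf.concat(
      [inference_embedding[1:], inference_embedding[:1]], axis=0)
  neg_dist = tf.reduce_sum(tf.square(condition_embedding - shifted), axis=-1)
  return tf.reduce_mean(tf.maximum(pos_dist - neg_dist + margin, 0.0))
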
def model_train_fn(self,
                   features,
                   labels,
                   inference_outputs,
                   mode,
                   config=None,
                   params=None):
  """Output learned loss if inner loop, or behavior clone if outer loop."""
  if params and params.get('is_outer_loss', False):
    # Outer loss case: use the standard RegressionModel loss.
    return self.loss_fn(labels, inference_outputs, mode, params)
  # Inner loss case: compute a learned loss function from the model's own
  # outputs (no labels required).
  with tf.variable_scope(
      'learned_loss', reuse=tf.AUTO_REUSE, use_resource=True):
    predicted_action, _ = meta_tfdata.multi_batch_apply(
        vision_layers.BuildImageFeaturesToPoseModel, 2,
        inference_outputs['feature_points'], num_outputs=self._action_size)
    if self._learned_loss_conv1d_layers is None:
      return tf.losses.mean_squared_error(predicted_action,
                                          inference_outputs['action'])
    ll_input = tf.concat([
        predicted_action, inference_outputs['feature_points'],
        inference_outputs['action']
    ], -1)
    net = ll_input
    for num_filters in self._learned_loss_conv1d_layers[:-1]:
      net = tf.layers.conv1d(
          net, num_filters, 10, activation=tf.nn.relu, use_bias=False)
      net = tf.contrib.layers.layer_norm(net)
    net = tf.layers.conv1d(
        net, self._learned_loss_conv1d_layers[-1], 1)  # 1x1 convolution.
    return tf.reduce_mean(tf.reduce_sum(tf.square(net), axis=(1, 2)))

def _preprocess_fn(self, features, labels, mode):
  """Resize images and convert them from uint8 -> float32."""
  if 'image' in features:
    ndim = len(features.image.shape)
    is_sequence = (ndim > 4)
    input_size = self._src_img_res
    target_size = self._crop_size
    features.original_image = features.image
    features.image = distortion.preprocess_image(features.image, mode,
                                                 is_sequence, input_size,
                                                 target_size)
    features.image = tf.image.convert_image_dtype(features.image, tf.float32)
    out_feature_spec = self.get_out_feature_specification(mode)
    if out_feature_spec.image.shape != features.image.shape:
      features.image = meta_tfdata.multi_batch_apply(
          tf.image.resize_images, 2, features.image,
          out_feature_spec.image.shape.as_list()[-3:-1])
  if self._mixup_alpha > 0. and labels and mode == TRAIN:
    # Mixup: blend each float tensor with its batch-reversed copy using a
    # single Beta-distributed mixing weight.
    lmbda = tfp.distributions.Beta(
        self._mixup_alpha, self._mixup_alpha).sample()
    for key, x in features.items():
      if x.dtype in FLOAT_DTYPES:
        features[key] = lmbda * x + (1 - lmbda) * tf.reverse(x, axis=[0])
    if labels is not None:
      for key, x in labels.items():
        if x.dtype in FLOAT_DTYPES:
          labels[key] = lmbda * x + (1 - lmbda) * tf.reverse(x, axis=[0])
  return features, labels

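# Illustrative sketch distilled from the mixup branch above (the helper name
# and default `alpha` are assumptions, not repository code): mixup draws one
# lambda ~ Beta(alpha, alpha) per batch and blends every float tensor with a
# reversed copy of itself, so example i is mixed with example
# (batch_size - 1 - i). A minimal standalone version for a single tensor:
def example_mixup(batch, alpha=0.2):
  """Blends a float batch tensor with its batch-reversed counterpart."""
  lmbda = tfp.distributions.Beta(alpha, alpha).sample()
  return lmbda * batch + (1. - lmbda) * tf.reverse(batch, axis=[0])
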
def __call__(self, params, output_size):
  """Applies the model.

  Args:
    params: Features conditioning the output distribution.
    output_size: Dimensionality of the output distribution.

  Returns:
    The approximate mode of the predicted action distribution.
  """
  # predict_mdn_params cannot handle meta-batches.
  dist_params = meta_tfdata.multi_batch_apply(
      predict_mdn_params, 3, params, self._num_mixture_components,
      output_size, condition_sigmas=False)
  # TODO(ejang): One limitation of a stateful module builder is that a wrapper
  # model (i.e. MAMLModel) may call this module multiple times and overwrite
  # self._gm before the appropriate loss() is called. A potential solution is
  # decoder objects that implement stateless functions (i.e. passing the state
  # to a handler). We should fix this later.
  self._gm = get_mixture_distribution(dist_params,
                                      self._num_mixture_components,
                                      output_size)
  action = gaussian_mixture_approximate_mode(self._gm)
  return action

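# Illustrative sketch, an assumption about what an "approximate mode" can mean
# rather than the repository's gaussian_mixture_approximate_mode: a Gaussian
# mixture has no closed-form mode, so a common approximation is to return the
# mean of the most probable component. Given raw MDN parameters, that looks
# roughly like this (the helper name and argument layout are hypothetical):
def example_approximate_mode(mixture_logits, component_means):
  """Picks the mean of the most likely mixture component.

  Args:
    mixture_logits: [batch, num_components] unnormalized mixture weights.
    component_means: [batch, num_components, action_size] component means.

  Returns:
    [batch, action_size] approximate mode of the mixture.
  """
  best = tf.argmax(mixture_logits, axis=-1)
  one_hot = tf.one_hot(best, depth=tf.shape(mixture_logits)[-1],
                       dtype=component_means.dtype)
  return tf.reduce_sum(component_means * one_hot[..., None], axis=-2)
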
def inference_network_fn(self, features, labels, mode, config=None,
                         params=None):
  """See base class."""
  condition_embedding = self._embed_episode(features.condition)
  gripper_pose = features.inference.features.gripper_pose
  fc_embedding = tf.tile(condition_embedding[:, :, None, :],
                         [1, 1, self._episode_length, 1])
  with tf.variable_scope(
      'state_features', reuse=tf.AUTO_REUSE, use_resource=True):
    state_features, _ = meta_tfdata.multi_batch_apply(
        vision_layers.BuildImagesToFeaturesModel, 3,
        features.inference.features.image)
  if self._ignore_embedding:
    fc_inputs = tf.concat([state_features, gripper_pose], -1)
  else:
    fc_inputs = tf.concat([state_features, gripper_pose, fc_embedding], -1)
  outputs = {}
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    if self._num_mixture_components > 1:
      dist_params = meta_tfdata.multi_batch_apply(
          mdn.predict_mdn_params, 3, fc_inputs,
          self._num_mixture_components, self._action_size, False)
      outputs['dist_params'] = dist_params
      gm = mdn.get_mixture_distribution(dist_params,
                                        self._num_mixture_components,
                                        self._action_size)
      action = mdn.gaussian_mixture_approximate_mode(gm)
    else:
      action, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          self._action_size)
  outputs['inference_output'] = action
  return outputs

def a_func(self,
           features,
           scope,
           mode,
           context_fn=None,
           reuse=tf.AUTO_REUSE,
           config=None,
           params=None):
  """Single step action predictor. See parent class."""
  return meta_tfdata.multi_batch_apply(self.single_batch_a_func, 2, features,
                                       scope, mode, context_fn, reuse, config,
                                       params)

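# Illustrative sketch, a hypothetical stand-in for
# meta_tfdata.multi_batch_apply (not the repository's implementation): the
# calls throughout this file rely on a helper that merges the first
# `num_batch_dims` dimensions into one, applies a function written for a
# single batch dimension, and then restores the leading dimensions on the
# result. The real helper also handles structured inputs such as the
# `features` spec passed above; this sketch covers only a single tensor with
# a single tensor output.
def example_multi_batch_apply(fn, num_batch_dims, tensor, *args, **kwargs):
  """Applies fn to a tensor after flattening its leading batch dimensions."""
  batch_shape = tf.shape(tensor)[:num_batch_dims]
  flat = tf.reshape(
      tensor, tf.concat([[-1], tf.shape(tensor)[num_batch_dims:]], axis=0))
  result = fn(flat, *args, **kwargs)
  out_shape = tf.concat([batch_shape, tf.shape(result)[1:]], axis=0)
  return tf.reshape(result, out_shape)
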
def _preprocess_fn(self, features, labels, mode):
  """Resize images and convert them from uint8 -> float32."""
  ndim = len(features.image.shape)
  is_sequence = (ndim > 4)
  input_size = self._src_img_res
  target_size = self._crop_size
  features.image = distortion.preprocess_image(features.image, mode,
                                               is_sequence, input_size,
                                               target_size)
  out_feature_spec = self.get_out_feature_specification(mode)
  if out_feature_spec.image.shape != features.image.shape:
    features.image = meta_tfdata.multi_batch_apply(
        tf.image.resize_images, 2,
        tf.image.convert_image_dtype(features.image, tf.float32),
        out_feature_spec.image.shape.as_list()[-3:-1])
  return features, labels

def a_func(self,
           features,
           scope,
           mode,
           context_fn=None,
           reuse=tf.AUTO_REUSE,
           config=None,
           params=None):
  """A (state) regression function.

  This function can return a stochastic or a deterministic tensor.

  Args:
    features: This is the first item returned from the input_fn and parsed by
      tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
      requirements of self.get_feature_specification.
    scope: String specifying the variable scope.
    mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
    context_fn: Optional python function that takes in features and returns
      new features of the same shape, e.g. for merging information as in RL^2.
    reuse: Whether or not to reuse variables under variable scope 'scope'.
    config: Optional configuration object. Will receive what is passed to
      Estimator in the config parameter, or the default config. Allows
      updating things in your model_fn based on configuration such as
      num_ps_replicas or model_dir.
    params: An optional dict of hyper parameters that will be passed into
      input_fn and model_fn. Keys are names of parameters, values are basic
      python types. There are reserved keys for TPUEstimator, including
      'batch_size'.

  Returns:
    outputs: A {key: Tensor} mapping. The key 'action' is required.
  """
  del config, params
  return meta_tfdata.multi_batch_apply(self._single_batch_a_func, 2, features,
                                       scope, mode, context_fn, reuse)

def inference_network_fn(self, features, labels, mode, config=None,
                         params=None):
  """See base class."""
  inf_full_state_pose = features.inference.features.full_state_pose
  con_full_state_pose = features.condition.features.full_state_pose
  # Map success labels [0, 1] -> [-1, 1].
  con_success = 2 * features.condition.labels.success - 1
  if self._retrial and con_full_state_pose.shape[1] != 2:
    raise ValueError('Unexpected shape {}.'.format(con_full_state_pose.shape))
  if self._embed_type == 'temporal':
    fc_embedding = meta_tfdata.multi_batch_apply(
        tec.reduce_temporal_embeddings, 2,
        con_full_state_pose[:, 0:1, :, :], self._fc_embed_size,
        'demo_embedding')[:, :, None, :]
  elif self._embed_type == 'mean':
    fc_embedding = con_full_state_pose[:, 0:1, -1:, :]
  else:
    raise ValueError('Invalid embed_type: {}.'.format(self._embed_type))
  fc_embedding = tf.tile(fc_embedding, [1, 1, 40, 1])
  if self._retrial:
    con_input = tf.concat([
        con_full_state_pose[:, 1:2, :, :], con_success[:, 1:2, :, :],
        fc_embedding
    ], -1)
    if self._embed_type == 'mean':
      trial_embedding = meta_tfdata.multi_batch_apply(
          tec.embed_fullstate, 3, con_input, self._fc_embed_size,
          'trial_embedding')
      trial_embedding = tf.reduce_mean(trial_embedding, -2)
    else:
      trial_embedding = meta_tfdata.multi_batch_apply(
          tec.reduce_temporal_embeddings, 2, con_input, self._fc_embed_size,
          'trial_embedding')
    trial_embedding = tf.tile(trial_embedding[:, :, None, :], [1, 1, 40, 1])
    fc_embedding = tf.concat([fc_embedding, trial_embedding], -1)
  if self._ignore_embedding:
    fc_inputs = inf_full_state_pose
  else:
    fc_inputs = [inf_full_state_pose, fc_embedding]
    if self._retrial:
      fc_inputs.append(con_success[:, 1:2, :, :])
    fc_inputs = tf.concat(fc_inputs, -1)
  outputs = {}
  with tf.variable_scope('a_func', reuse=tf.AUTO_REUSE, use_resource=True):
    if self._num_mixture_components > 1:
      fc_inputs, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          num_outputs=None)
      dist_params = meta_tfdata.multi_batch_apply(
          mdn.predict_mdn_params, 3, fc_inputs,
          self._num_mixture_components, self._action_size, False)
      outputs['dist_params'] = dist_params
      gm = mdn.get_mixture_distribution(dist_params,
                                        self._num_mixture_components,
                                        self._action_size)
      action = mdn.gaussian_mixture_approximate_mode(gm)
    else:
      action, _ = meta_tfdata.multi_batch_apply(
          vision_layers.BuildImageFeaturesToPoseModel, 3, fc_inputs,
          self._action_size)
  outputs.update({
      'inference_output': action,
  })
  return outputs