def train(self, max_num_steps, time_step, policy_state):
    """Run on-policy training for about `max_num_steps` environment steps.

    The step budget is converted into a number of `self._iter` invocations
    (rounded up) which are executed inside a single `tf.while_loop`.

    Args:
        max_num_steps (int): stops after so many environment steps. It is
            the total number of steps from all the individual environments
            in the batched environment, including StepType.LAST steps.
        time_step (ActionTimeStep): initial time_step handed to the first
            iteration. Elements should be shape [batch_size, ...].
            NOTE(review): `None` is not handled in this method; any default
            has to be resolved by the caller.
        policy_state (nested Tensor): initial state for the policy.

    Returns:
        tuple: `(time_step, policy_state)` as produced by the last loop
            iteration.
    """
    # The comparison yields a bool that is added as 0/1: one extra step per
    # iteration is accounted for in FINAL_STEP_SKIP mode.
    steps_per_iteration = self._env.batch_size * (
        self._train_interval +
        (self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP))
    maximum_iterations = math.ceil(max_num_steps / steps_per_iteration)

    def _always(*_):
        # Termination is driven solely by `maximum_iterations`.
        return True

    loop_result = tf.while_loop(
        cond=_always,
        body=self._iter,
        loop_vars=[time_step, policy_state],
        maximum_iterations=maximum_iterations,
        back_prop=False,
        name="driver_loop")
    [time_step, policy_state] = loop_result
    return time_step, policy_state
def em_loop(
    num_updates,
    e_step,
    m_step,
    variables,
):
    """Run `num_updates` rounds of expectation-maximization on `variables`.

    Args:
        num_updates: number of rounds; the loop continues while the step
            counter is below this value.
        e_step: callable taking the current variables and returning a pair
            `(train_predictions, responsibilities)`.
        m_step: callable taking `(variables, train_predictions,
            responsibilities)` and returning the updated variables.
        variables: initial value of the variables carried through the loop.

    Returns:
        The variables returned by the loop after the last round.
    """

    def _keep_going(step, *unused_loop_vars):
        del unused_loop_vars
        return step < num_updates

    def _one_round(step, current_vars):
        # E-step on the current variables, then M-step consuming its outputs.
        predictions, responsibilities = e_step(current_vars)
        next_vars = m_step(current_vars, predictions, responsibilities)
        return step + 1, next_vars

    step_counter = tf.Variable(0, trainable=False, name='inner_step_counter')
    _, final_vars = tf.while_loop(
        cond=_keep_going,
        body=_one_round,
        loop_vars=(step_counter, variables),
        swap_memory=True)
    return final_vars
def grad_fn(dy):
    """Compute gradients one query at a time in a while loop to save memory.

    For each index along the first axis of `query_queries`, the distance is
    recomputed with `self._get_dist` and differentiated with respect to the
    query slice and the support tensors, using column `idx` of `dy` as the
    upstream gradient. Query gradients are concatenated along axis 0 and
    support gradients are summed.

    Args:
        dy: upstream gradient; it is sliced as `dy[:, idx:idx + 1]`.

    Returns:
        A 5-tuple: gradients for the query queries, query values, support
        keys and support values, followed by `None`.
    """
    keys_id = tf.identity(support_keys)
    values_id = tf.identity(support_values)

    def _empty_rows(reference):
        # Zeros shaped like one row of `reference` with a new leading axis,
        # truncated to `zero_dim` rows.
        # NOTE(review): presumably `zero_dim` is 0 so the accumulator starts
        # empty - confirm against the enclosing scope.
        row = tf.zeros(tf.shape(reference)[1:], dtype=dy.dtype)
        return row[tf.newaxis, :][:zero_dim]

    qq_start = _empty_rows(query_queries)
    qv_start = _empty_rows(query_values)
    sk_start = tf.zeros(tf.shape(keys_id), dtype=dy.dtype)
    sv_start = tf.zeros(tf.shape(values_id), dtype=dy.dtype)
    start_state = (0, qq_start, qv_start, sk_start, sv_start)

    def _has_more(idx, *unused_accumulators):
        return idx < tf.shape(query_queries)[0]

    def _accumulate(idx, qq_grad, qv_grad, sk_grad, sv_grad):
        """Add the gradient contribution of the query at position `idx`."""
        single_qq = query_queries[idx:idx + 1]
        single_qv = query_values[idx:idx + 1]
        dist = self._get_dist(single_qq, single_qv, keys_id, values_id,
                              labels)
        wrt = [single_qq, single_qv, keys_id, values_id]
        partials = tf.gradients(dist, wrt, grad_ys=dy[:, idx:idx + 1])
        # Query gradients are appended row by row; support gradients are
        # accumulated because every query touches the whole support set.
        qq_grad = tf.concat([qq_grad, partials[0]], axis=0)
        qv_grad = tf.concat([qv_grad, partials[1]], axis=0)
        sk_grad = sk_grad + partials[2]
        sv_grad = sv_grad + partials[3]
        return (idx + 1, qq_grad, qv_grad, sk_grad, sv_grad)

    _, qq_total, qv_total, sk_total, sv_total = tf.while_loop(
        _has_more, _accumulate, start_state, parallel_iterations=1)
    return (qq_total, qv_total, sk_total, sv_total, None)
def rollout(self, max_num_steps, time_step, policy_state):
    """Roll out `self._rollout_loop_body` without back-propagation.

    The loop runs for `ceil(max_num_steps / batch_size)` iterations and
    collects the training info of every iteration in TensorArrays, which
    are stacked and passed to the algorithm's summary methods.

    Args:
        max_num_steps (int): step budget; divided by the environment batch
            size to obtain the number of loop iterations.
        time_step: initial time step passed to the loop as a loop variable.
        policy_state: initial policy state passed to the loop as a loop
            variable.

    Returns:
        tuple: `(time_step, policy_state)` as returned by the loop.
    """
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size
    maximum_iterations = math.ceil(max_num_steps / batch_size)

    def _make_array(spec):
        # One slot per loop iteration; each element has a leading batch axis.
        element_shape = tf.TensorShape([batch_size]).concatenate(spec.shape)
        return tf.TensorArray(
            dtype=spec.dtype,
            size=maximum_iterations,
            element_shape=element_shape)

    # `rollout_info` is stored via its distribution-parameter spec and turned
    # back into distributions after stacking.
    rollout_param_spec = nest_utils.to_distribution_param_spec(
        self._training_info_spec.rollout_info)
    array_spec = self._training_info_spec._replace(
        rollout_info=rollout_param_spec)
    training_info_ta = tf.nest.map_structure(_make_array, array_spec)

    def _always(*_):
        # Termination is driven solely by `maximum_iterations`.
        return True

    loop_result = tf.while_loop(
        cond=_always,
        body=self._rollout_loop_body,
        loop_vars=[counter, time_step, policy_state, training_info_ta],
        maximum_iterations=maximum_iterations,
        back_prop=False,
        name="rollout_loop")
    [counter, time_step, policy_state, training_info_ta] = loop_result

    stacked_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                         training_info_ta)
    training_info = nest_utils.params_to_distributions(
        stacked_info, self._training_info_spec)

    self._algorithm.summarize_rollout(training_info)
    self._algorithm.summarize_metrics()
    return time_step, policy_state
def per_voxel_point_sample_segment_func(data, segment_ids, num_segments,
                                        num_samples_per_voxel):
    """Samples features from the points within each voxel.

    Up to `num_samples_per_voxel` points are selected per segment, one per
    loop iteration, each time taking the not-yet-selected point with the
    largest index in the segment. Missing samples gather an all-zero row.

    Args:
        data: A tf.float32 tensor of size [N, F].
        segment_ids: A tf.int32 tensor of size [N].
        num_segments: Number of segments.
        num_samples_per_voxel: Number of features to sample per voxel. If the
            voxel has less number of points in it, the point features will be
            padded by 0.

    Returns:
        A tf.float32 tensor of size [num_segments, num_samples_per_voxel, F].

    Raises:
        ValueError: If the static size of the second dimension of `data` is
            unknown.
    """
    feature_dim = data.get_shape().as_list()[1]
    if feature_dim is None:
        raise ValueError('num_channels is None.')
    num_points = tf.shape(segment_ids)[0]

    def _has_more_columns(column, unused_remaining, unused_selected):
        return column < num_samples_per_voxel

    def _fill_column(column, remaining, selected):
        """Writes one selected point index per segment into `column`."""
        # `remaining` holds 1-based point indices, with 0 marking points that
        # were already selected, so a positive maximum is a valid pick.
        top = tf.math.unsorted_segment_max(
            data=remaining,
            segment_ids=segment_ids,
            num_segments=num_segments)
        found = tf.greater(top, 0)
        found_ids = tf.boolean_mask(top, found)
        # 1 at every point selected in this round (converted to 0-based).
        just_selected = tf.scatter_nd(
            indices=tf.cast(
                tf.expand_dims(found_ids - 1, axis=1), dtype=tf.int64),
            updates=tf.ones_like(found_ids, dtype=tf.int32),
            shape=(num_points,))
        remaining = remaining * (1 - just_selected)
        # Segments without a valid pick contribute index 0.
        top = top * tf.cast(found, dtype=tf.int32)
        # Place the picks in column `column` of a zero matrix and merge.
        column_block = tf.pad(
            tf.expand_dims(top, axis=1),
            paddings=[[0, 0], [column, num_samples_per_voxel - column - 1]])
        selected = selected + column_block
        return column + 1, remaining, selected

    start_state = (
        tf.constant(0, dtype=tf.int32),
        tf.range(num_points) + 1,
        tf.zeros([num_segments, num_samples_per_voxel], dtype=tf.int32),
    )
    _, _, selected_indices = tf.while_loop(
        cond=_has_more_columns, body=_fill_column, loop_vars=start_state)

    # Prepend a zero row so that index 0 (no point) gathers zero features and
    # the 1-based indices line up with the original rows.
    padded_data = tf.pad(data, paddings=[[1, 0], [0, 0]])
    flat_features = tf.gather(padded_data,
                              tf.reshape(selected_indices, [-1]))
    return tf.reshape(flat_features,
                      [num_segments, num_samples_per_voxel, feature_dim])
def predict(self, max_num_steps, time_step, policy_state):
    """Run `self._eval_loop_body` repeatedly without back-propagation.

    Args:
        max_num_steps (int): step budget; the loop runs for
            `ceil(max_num_steps / batch_size)` iterations, where
            `batch_size` is that of `self._env`.
        time_step: initial time step passed to the loop as a loop variable.
        policy_state: initial policy state passed to the loop as a loop
            variable.

    Returns:
        tuple: `(time_step, policy_state)` as returned by the loop.
    """
    iteration_budget = math.ceil(max_num_steps / self._env.batch_size)

    def _always(*_):
        # Termination is driven solely by `maximum_iterations`.
        return True

    loop_result = tf.while_loop(
        cond=_always,
        body=self._eval_loop_body,
        loop_vars=[time_step, policy_state],
        maximum_iterations=iteration_budget,
        back_prop=False,
        name="predict_loop")
    [time_step, policy_state] = loop_result
    return time_step, policy_state
def _iter(self, time_step, policy_state):
    """One training iteration.

    Runs `self._train_loop_body` for `self._train_interval` loop iterations
    under a gradient tape, stacks the collected training info and passes it,
    together with the tape, to `self._algorithm.train_complete`. Afterwards
    the summaries are written and the global counter is incremented by one.

    Args:
        time_step: time step passed to the loop as a loop variable.
        policy_state: policy state passed to the loop as a loop variable.

    Returns:
        list: `[next_time_step, next_state]` as returned by the loop.
    """
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        # One slot per loop iteration; each element has a leading batch axis.
        return tf.TensorArray(dtype=s.dtype,
                              size=self._train_interval,
                              element_shape=tf.TensorShape(
                                  [batch_size]).concatenate(s.shape))

    # `info` is stored via its distribution-parameter spec; it is converted
    # back by `params_to_distributions` after stacking.
    training_info_ta = tf.nest.map_structure(
        create_ta,
        self._training_info_spec._replace(
            info=nest_utils.to_distribution_param_spec(
                self._training_info_spec.info)))

    # Only the explicitly watched variables are tracked. The tape is
    # persistent and is handed to `train_complete` below.
    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        # The loop is bounded by `maximum_iterations` only; `cond` is
        # always true.
        [counter, next_time_step, next_state, training_info_ta
         ] = tf.while_loop(cond=lambda *_: True,
                           body=self._train_loop_body,
                           loop_vars=[
                               counter, time_step, policy_state,
                               training_info_ta
                           ],
                           back_prop=True,
                           parallel_iterations=1,
                           maximum_iterations=self._train_interval,
                           name='iter_loop')

        # NOTE(review): stacking and the conversion to distributions are
        # kept under the tape so that they are recorded for differentiation
        # - confirm this scope against the upstream source.
        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)

        training_info = nest_utils.params_to_distributions(
            training_info, self._training_info_spec)

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)

    # Drop the local reference to the persistent tape once it is no longer
    # needed.
    del tape

    self._algorithm.summarize_train(training_info, loss_info,
                                    grads_and_vars)
    self._algorithm.summarize_metrics()

    common.get_global_counter().assign_add(1)

    return [next_time_step, next_state]
def fwd_fn(query_queries_fwd, query_values_fwd, support_keys_fwd,
           support_values_fwd, labels_fwd):
    """CrossTransformer forward pass, one query at a time to save memory.

    Each loop iteration computes `self._get_dist` for a single query slice
    and appends the result along axis 1 of the accumulated tensor.

    Args:
        query_queries_fwd: query-side queries; sliced along axis 0.
        query_values_fwd: query-side values; sliced along axis 0.
        support_keys_fwd: support keys, passed unchanged to `_get_dist`.
        support_values_fwd: support values, passed unchanged to `_get_dist`.
        labels_fwd: labels, passed unchanged to `_get_dist`.

    Returns:
        The concatenation along axis 1 of the per-query results.
    """
    # NOTE(review): the row count is taken from the enclosing scope's
    # `labels`, not from `labels_fwd` - confirm both refer to the same
    # tensor.
    num_rows = tf.reduce_max(labels) + 1
    empty_dist = tf.zeros([num_rows, zero_dim],
                          dtype=query_queries_fwd.dtype)

    def _has_more(idx, _):
        return idx < tf.shape(query_queries_fwd)[0]

    def _append_query(idx, dist):
        single_dist = self._get_dist(
            query_queries_fwd[idx:idx + 1],
            query_values_fwd[idx:idx + 1],
            support_keys_fwd,
            support_values_fwd,
            labels_fwd)
        extended = tf.concat([dist, single_dist], axis=1)
        return (idx + 1, extended)

    _, res = tf.while_loop(
        _has_more, _append_query, (0, empty_dist), parallel_iterations=1)
    return res
def optimizer_loop(
    num_updates,
    objective_fn,
    update_fn,
    variables,
    first_order,
    clip_grad_norm,
):
    """Optimize `objective_fn` over `variables` for `num_updates` steps.

    Args:
        num_updates: number of update steps; the loop continues while the
            step counter is below this value.
        objective_fn: objective forwarded to `optimizer_update`.
        update_fn: zero-argument callable returning the triple
            `(init, update, get_params)` of optimizer functions.
        variables: iterable of variables; each one is passed through `init`
            to build the initial loop state.
        first_order: forwarded to `optimizer_update`.
        clip_grad_norm: forwarded to `optimizer_update`.

    Returns:
        list: `get_params` applied to every element of the final loop state.
    """
    # Optimizer specifics.
    init_state, apply_update, get_params = update_fn()

    def _keep_going(step, *unused_loop_vars):
        """Optimization truncation condition."""
        del unused_loop_vars
        return step < num_updates

    def _one_update(step, current_state):
        """Optimization loop body."""
        next_state = optimizer_update(
            iterate_collection=current_state,
            iteration_idx=step,
            objective_fn=objective_fn,
            update_fn=apply_update,
            get_params_fn=get_params,
            first_order=first_order,
            clip_grad_norm=clip_grad_norm,
        )
        return step + 1, next_state

    step_counter = tf.Variable(0, trainable=False, name='inner_step_counter')
    initial_state = [init_state(var) for var in variables]
    _, final_state = tf.while_loop(
        cond=_keep_going,
        body=_one_update,
        loop_vars=(step_counter, initial_state),
        swap_memory=True)
    return [get_params(state) for state in final_state]
def _iter(self, time_step, policy_state):
    """One training iteration.

    Runs `self._train_loop_body` for `self._train_interval` loop iterations
    under a gradient tape, appends the training info of one final step,
    and passes the stacked training info together with the tape to
    `self._algorithm.train_complete`.

    Args:
        time_step: time step passed to the loop as a loop variable.
        policy_state: policy state passed to the loop as a loop variable.

    Returns:
        tuple: `(next_time_step, next_state)` for the following iteration.
    """
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        # `train_interval + 1` slots: one per loop iteration plus one for
        # the final step written after the loop.
        return tf.TensorArray(dtype=s.dtype,
                              size=self._train_interval + 1,
                              element_shape=tf.TensorShape(
                                  [batch_size]).concatenate(s.shape))

    training_info_ta = tf.nest.map_structure(create_ta,
                                             self._training_info_spec)

    # Only the explicitly watched variables are tracked. The tape is
    # persistent because it is re-entered below and then handed to
    # `train_complete`.
    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        # The loop is bounded by `maximum_iterations` only; `cond` is
        # always true.
        [counter, time_step, policy_state, training_info_ta
         ] = tf.while_loop(cond=lambda *_: True,
                           body=self._train_loop_body,
                           loop_vars=[
                               counter, time_step, policy_state,
                               training_info_ta
                           ],
                           back_prop=True,
                           parallel_iterations=1,
                           maximum_iterations=self._train_interval,
                           name='iter_loop')

    if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
        # The final step goes through `self._step`; the next iteration
        # starts from the time step and state it returns.
        next_time_step, policy_step, action = self._step(
            time_step, policy_state)
        next_state = policy_step.state
    else:
        # The final step only evaluates the algorithm's rollout; the next
        # iteration starts again from the same time step and policy state.
        policy_step = common.algorithm_step(self._algorithm.rollout,
                                            self._observation_transformer,
                                            time_step, policy_state)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = time_step
        next_state = policy_state

    action_distribution_param = common.get_distribution_params(
        policy_step.action)

    final_training_info = make_training_info(
        action_distribution=action_distribution_param,
        action=action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=policy_step.info)

    # NOTE(review): the tape is re-entered for the write, the stack and the
    # rebuilding of the distributions, presumably so they are recorded for
    # differentiation - confirm this scope against the upstream source.
    with tape:
        # `counter` is the value returned by the loop and is used as the
        # slot index for the final step.
        training_info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), training_info_ta,
            final_training_info)
        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)

        # The arrays hold distribution parameters; turn them back into
        # distribution objects according to the algorithm's spec.
        action_distribution = nested_distributions_from_specs(
            self._algorithm.action_distribution_spec,
            training_info.action_distribution)

        training_info = training_info._replace(
            action_distribution=action_distribution)

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)

    # Drop the local reference to the persistent tape once it is no longer
    # needed.
    del tape

    self._training_summary(training_info, loss_info, grads_and_vars)

    self._train_step_counter.assign_add(1)

    return next_time_step, next_state
def forward_pass(self, data):
    """Computes the test logits of MAML.

    The embedding and fully-connected variables are adapted on the support
    set inside a `tf.while_loop`, and the adapted values are then used to
    embed and classify the query images.

    Args:
        data: A `meta_dataset.providers.Episode` containing the data for the
            episode.

    Returns:
        The output logits for the query data in this episode.
    """
    # Have to use one-hot labels since sparse softmax doesn't allow
    # second derivatives.
    support_embeddings_ = self.embedding_fn(
        data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
    support_embeddings = support_embeddings_['embeddings']
    embedding_vars_dict = support_embeddings_['params']

    # TODO(eringrant): Refactor to make use of
    # `functional_backbones.linear_classifier`, which allows
    # Gin-configuration.
    with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
        embedding_depth = support_embeddings.shape.as_list()[-1]
        fc_weights = functional_backbones.weight_variable(
            [embedding_depth, self.logit_dim],
            weight_decay=self.classifier_weight_decay)
        fc_bias = functional_backbones.bias_variable([self.logit_dim])

    # A list of variable names, a list of corresponding Variables, and a list
    # of operations (possibly empty) that creates a copy of each Variable.
    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         embedding_vars_dict, make_copies=not self.is_training)

    # A Variable for the weights of the fc layer, a Variable for the bias of
    # the fc layer, and a list of operations (possibly empty) that copies
    # them.
    (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
        fc_weights, fc_bias, make_copies=not self.is_training)

    fc_vars = [fc_weights, fc_bias]
    num_embedding_vars = len(embedding_vars)
    num_fc_vars = len(fc_vars)

    def _cond(step, *args):
        # Continue while `step` is below the number of inner updates; extra
        # update steps are added when not training.
        del args
        num_steps = self.num_update_steps
        if not self.is_training:
            num_steps += self.additional_evaluation_update_steps
        return step < num_steps

    def _body(step, *args):
        """The inner update loop body."""
        # `args` is the flat list of loop variables after `step`: first the
        # embedding variables, then the fc weights and bias.
        updated_embedding_vars = args[0:num_embedding_vars]
        updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                               num_fc_vars]
        support_embeddings = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, updated_embedding_vars)),
            reuse=True)['embeddings']

        updated_fc_weights, updated_fc_bias = updated_fc_vars
        support_logits = tf.matmul(support_embeddings,
                                   updated_fc_weights) + updated_fc_bias

        # Only the first `data.way` logit columns are used for this episode.
        support_logits = support_logits[:, 0:data.way]
        loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                               support_logits)

        print_op = tf.no_op()
        if self.debug_log:
            print_op = tf.print(
                ['step: ', step, updated_fc_bias[0], 'loss:', loss])

        # Tie the (optional) debug print to the update so that it runs on
        # every loop iteration.
        with tf.control_dependencies([print_op]):
            updated_embedding_vars = gradient_descent_step(
                loss, updated_embedding_vars, self.first_order,
                self.adapt_batch_norm, self.alpha, False)['updated_vars']
            updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                                    self.first_order,
                                                    self.adapt_batch_norm,
                                                    self.alpha,
                                                    False)['updated_vars']

            step = step + 1
        return tuple([step] + list(updated_embedding_vars) +
                     list(updated_fc_vars))

    # MAML meta updates using query set examples from an episode.
    if self.zero_fc_layer:
        # To account for variable class sizes, we initialize the output
        # weights to zero. See if truncated normal initialization will help.
        zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
        zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
        fc_vars_init_ops = [zero_weights_op, zero_bias_op]
    else:
        fc_vars_init_ops = fc_vars_copy_ops

    if self.proto_maml_fc_layer_init:
        # Replace the fc loop variables by weights and bias derived from the
        # class prototypes of the support embeddings.
        support_embeddings = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, embedding_vars)),
            reuse=True)['embeddings']

        prototypes = metric_learners.compute_prototypes(
            support_embeddings, data.onehot_support_labels)
        pmaml_fc_weights = self.proto_maml_fc_weights(
            prototypes, zero_pad_to_max_way=True)
        pmaml_fc_bias = self.proto_maml_fc_bias(
            prototypes, zero_pad_to_max_way=True)
        fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

    # These control dependencies assign the value of each variable to a new
    # copy variable that corresponds to it. This is required at test time for
    # initializing the copies as they are used in place of the original vars.
    with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
        # Make step a local variable as we don't want to save and restore it.
        step = tf.Variable(
            0,
            trainable=False,
            name='inner_step_counter',
            collections=[tf.GraphKeys.LOCAL_VARIABLES])
        loop_vars = [step] + embedding_vars + fc_vars
        step_and_all_updated_vars = tf.while_loop(
            _cond, _body, loop_vars, swap_memory=True)
        # Split the flat loop result back into the step, the embedding
        # variables and the fc weights/bias.
        step = step_and_all_updated_vars[0]
        all_updated_vars = step_and_all_updated_vars[1:]
        updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
        updated_fc_weights, updated_fc_bias = all_updated_vars[
            num_embedding_vars:num_embedding_vars + num_fc_vars]

    # Forward pass the training images with the updated weights in order to
    # compute the means and variances, to use for the query's batch norm.
    support_set_moments = None
    if not self.transductive_batch_norm:
        support_set_moments = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, updated_embedding_vars)),
            reuse=True)['moments']

    query_embeddings = self.embedding_fn(
        data.query_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        moments=support_set_moments,  # Use support set stats for batch norm.
        reuse=True,
        backprop_through_moments=self.backprop_through_moments)['embeddings']

    # As for the support logits, keep only the first `data.way` columns.
    query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                    updated_fc_bias)[:, 0:data.way]

    return query_logits