Example #1
    def train(self, max_num_steps, time_step, policy_state):
        """Perform on-policy training with `max_num_steps`.

        Args:
            max_num_steps (int): stops after so many environment steps. Is the
                total number of steps from all the individual environment in
                the bached environments including StepType.LAST steps.
            time_step (ActionTimeStep): optional initial time_step. If None, it
                will use self.get_initial_time_step(). Elements should be shape
                [batch_size, ...].
            policy_state (nested Tensor): optional initial state for the policy.
        Returns:
            None
        """
        maximum_iterations = math.ceil(
            max_num_steps /
            (self._env.batch_size *
             (self._train_interval +
              (self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP))))
        [time_step,
         policy_state] = tf.while_loop(cond=lambda *_: True,
                                       body=self._iter,
                                       loop_vars=[time_step, policy_state],
                                       maximum_iterations=maximum_iterations,
                                       back_prop=False,
                                       name="driver_loop")
        return time_step, policy_state
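The `cond=lambda *_: True` plus `maximum_iterations` combination above is how these drivers get a fixed-trip-count loop. A minimal standalone illustration of that pattern (a hypothetical toy loop, not part of the driver):

import tensorflow as tf

# With a constant-true `cond`, `maximum_iterations` alone bounds the loop,
# so the body runs exactly `maximum_iterations` times.
[x] = tf.while_loop(cond=lambda *_: True,
                    body=lambda x: [x + 1],
                    loop_vars=[tf.constant(0)],
                    maximum_iterations=10)
# x == 10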
Example #2
def em_loop(
    num_updates,
    e_step,
    m_step,
    variables,
):
    """Expectation-maximization of objective_fn wrt variables for num_updates."""
    def _body(step, preupdate_vars):
        train_predictions_, responsibilities_ = e_step(preupdate_vars)
        updated_vars = m_step(preupdate_vars, train_predictions_,
                              responsibilities_)
        return step + 1, updated_vars

    def _cond(step, *args):
        del args
        return step < num_updates

    step = tf.Variable(0, trainable=False, name='inner_step_counter')
    loop_vars = (step, variables)
    step, updated_vars = tf.while_loop(cond=_cond,
                                       body=_body,
                                       loop_vars=loop_vars,
                                       swap_memory=True)

    return updated_vars
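To make the `e_step`/`m_step` contract concrete, here is a toy wiring (TF2 eager; the callables are hypothetical stand-ins, not the objective this repo optimizes):

import tensorflow as tf

data = tf.constant([1.0, 2.0, 3.0])

def toy_e_step(params):
    preds = params[0] * tf.ones_like(data)  # "predictions" for each point
    resp = data - preds                     # residuals act as responsibilities
    return preds, resp

def toy_m_step(params, preds, resp):
    # Move the single parameter halfway toward the mean residual.
    return [params[0] + 0.5 * tf.reduce_mean(resp)]

final = em_loop(num_updates=20, e_step=toy_e_step, m_step=toy_m_step,
                variables=[tf.constant(0.0)])
# final[0] converges to 2.0, the mean of `data`.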
Example #3
    def grad_fn(dy):
      """Compute gradients using a while loop to save memory."""
      support_keys_id = tf.identity(support_keys)
      support_values_id = tf.identity(support_values)
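      # Slicing with the tensor `zero_dim` below appears to make the leading
      # dimension statically unknown, so the `tf.concat` calls in the loop can
      # grow it without explicit `shape_invariants` (`zero_dim` is a scalar 0
      # tensor from the enclosing scope).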
      initial = (0, tf.zeros(tf.shape(query_queries)[1:],
                             dtype=dy.dtype)[tf.newaxis, :][:zero_dim],
                 tf.zeros(tf.shape(query_values)[1:],
                          dtype=dy.dtype)[tf.newaxis, :][:zero_dim],
                 tf.zeros(tf.shape(support_keys_id), dtype=dy.dtype),
                 tf.zeros(tf.shape(support_values_id), dtype=dy.dtype))

      def loop_body(idx, qq_grad, qv_grad, sk_grad, sv_grad):
        """Compute gradients for a single query."""
        qq = query_queries[idx:idx + 1]
        qv = query_values[idx:idx + 1]
        x = self._get_dist(qq, qv, support_keys_id, support_values_id, labels)
        grads = tf.gradients(
            x, [qq, qv, support_keys_id, support_values_id],
            grad_ys=dy[:, idx:idx + 1])
        qq_grad = tf.concat([qq_grad, grads[0]], axis=0)
        qv_grad = tf.concat([qv_grad, grads[1]], axis=0)
        sk_grad += grads[2]
        sv_grad += grads[3]
        return (idx + 1, qq_grad, qv_grad, sk_grad, sv_grad)

      agg_grads = tf.while_loop(
          lambda *arg: arg[0] < tf.shape(query_queries)[0],
          loop_body,
          initial,
          parallel_iterations=1)
      return agg_grads[1:] + (None,)
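A hedged sketch of how a backward function like `grad_fn` above typically pairs with a forward while loop under `tf.custom_gradient` (TF1-style graph mode; all names here are hypothetical stand-ins, and the row scorer replaces this repo's `self._get_dist`):

import tensorflow as tf

@tf.custom_gradient
def chunked_scores(queries, supports):
  """Scores one query row at a time to bound peak memory."""
  supports_id = tf.identity(supports)  # distinct node to target in tf.gradients
  num_q = tf.shape(queries)[0]

  def row_fn(q):
    # Negative squared Euclidean distance of one query against all supports.
    return -tf.reduce_sum(tf.square(q - supports_id), axis=-1)[tf.newaxis, :]

  def fwd_body(idx, rows):
    return idx + 1, tf.concat([rows, row_fn(queries[idx:idx + 1])], axis=0)

  _, scores = tf.while_loop(
      lambda i, _: i < num_q,
      fwd_body,
      (tf.constant(0), tf.zeros([0, tf.shape(supports)[0]], queries.dtype)),
      shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None])),
      parallel_iterations=1)  # serialize iterations to cap memory

  def grad_fn(dy):
    def bwd_body(idx, q_grad, s_grad):
      q = queries[idx:idx + 1]
      row = row_fn(q)  # recompute instead of storing activations
      gq, gs = tf.gradients(row, [q, supports_id], grad_ys=dy[idx:idx + 1])
      return idx + 1, tf.concat([q_grad, gq], axis=0), s_grad + gs

    _, q_grad, s_grad = tf.while_loop(
        lambda i, *_: i < num_q,
        bwd_body,
        (tf.constant(0),
         tf.zeros([0, tf.shape(queries)[1]], dy.dtype),
         tf.zeros(tf.shape(supports_id), dy.dtype)),
        shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]),
                          supports_id.get_shape()),
        parallel_iterations=1)
    return q_grad, s_grad

  return scores, grad_fn

Recomputing each row inside grad_fn and serializing with parallel_iterations=1 trades compute for memory, which is exactly the trade the snippet above makes.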
Example #4
    def rollout(self, max_num_steps, time_step, policy_state):
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size
        maximum_iterations = math.ceil(max_num_steps / self._env.batch_size)

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=maximum_iterations,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(
            create_ta,
            self._training_info_spec._replace(
                rollout_info=nest_utils.to_distribution_param_spec(
                    self._training_info_spec.rollout_info)))

        [counter, time_step, policy_state, training_info_ta] = tf.while_loop(
            cond=lambda *_: True,
            body=self._rollout_loop_body,
            loop_vars=[counter, time_step, policy_state, training_info_ta],
            maximum_iterations=maximum_iterations,
            back_prop=False,
            name="rollout_loop")

        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)

        training_info = nest_utils.params_to_distributions(
            training_info, self._training_info_spec)

        self._algorithm.summarize_rollout(training_info)
        self._algorithm.summarize_metrics()

        return time_step, policy_state
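The TensorArray bookkeeping above follows a standard recipe: pre-size one array per spec leaf, write once per loop step, then stack() after the loop. A minimal self-contained version of the recipe (toy shapes, not the driver's specs):

import tensorflow as tf

ta = tf.TensorArray(tf.float32, size=3, element_shape=tf.TensorShape([2]))

def body(i, ta):
    # Write one [2]-shaped element per iteration.
    return i + 1, ta.write(i, tf.fill([2], tf.cast(i, tf.float32)))

_, ta = tf.while_loop(lambda i, _: i < 3, body, (tf.constant(0), ta))
stacked = ta.stack()  # shape [3, 2]: [[0., 0.], [1., 1.], [2., 2.]]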
Example #5
def per_voxel_point_sample_segment_func(data, segment_ids, num_segments,
                                        num_samples_per_voxel):
    """Samples features from the points within each voxel.

  Args:
    data: A tf.float32 tensor of size [N, F].
    segment_ids: A tf.int32 tensor of size [N].
    num_segments: Number of segments.
    num_samples_per_voxel: Number of features to sample per voxel. If a voxel
      contains fewer points than this, its sampled features are padded with 0.

  Returns:
    A tf.float32 tensor of size [num_segments, num_samples_per_voxel, F]
    containing the per-voxel sampled features, zero-padded for voxels with
    fewer than num_samples_per_voxel points.
  """
    num_channels = data.get_shape().as_list()[1]
    if num_channels is None:
        raise ValueError('num_channels is None.')
    n = tf.shape(segment_ids)[0]

    def _body_fn(i, indices_range, indices):
        """Computes the indices of the i-th point feature in each segment."""
        indices_i = tf.math.unsorted_segment_max(data=indices_range,
                                                 segment_ids=segment_ids,
                                                 num_segments=num_segments)
        indices_i_positive_mask = tf.greater(indices_i, 0)
        indices_i_positive = tf.boolean_mask(indices_i,
                                             indices_i_positive_mask)
        boolean_mask = tf.scatter_nd(indices=tf.cast(tf.expand_dims(
            indices_i_positive - 1, axis=1),
                                                     dtype=tf.int64),
                                     updates=tf.ones_like(indices_i_positive,
                                                          dtype=tf.int32),
                                     shape=(n, ))
        indices_range *= (1 - boolean_mask)
        indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
        indices_i = tf.pad(tf.expand_dims(indices_i, axis=1),
                           paddings=[[0, 0],
                                     [i, num_samples_per_voxel - i - 1]])
        indices += indices_i
        i = i + 1
        return i, indices_range, indices

    cond = lambda i, indices_range, indices: i < num_samples_per_voxel

    (_, _, indices) = tf.while_loop(
        cond=cond,
        body=_body_fn,
        loop_vars=(tf.constant(0, dtype=tf.int32), tf.range(n) + 1,
                   tf.zeros([num_segments, num_samples_per_voxel],
                            dtype=tf.int32)))

    data = tf.pad(data, paddings=[[1, 0], [0, 0]])
    voxel_features = tf.gather(data, tf.reshape(indices, [-1]))
    return tf.reshape(voxel_features,
                      [num_segments, num_samples_per_voxel, num_channels])
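A quick usage sketch with hypothetical toy inputs (5 points, 2 features each, spread over 3 voxels, sampling up to 2 points per voxel):

import tensorflow as tf

data = tf.constant([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]])
segment_ids = tf.constant([0, 0, 1, 2, 2], dtype=tf.int32)
features = per_voxel_point_sample_segment_func(
    data, segment_ids, num_segments=3, num_samples_per_voxel=2)
# features has shape [3, 2, 2]; voxel 1 holds a single point, so its second
# sample row is all zeros (the padding described in the docstring).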
Example #6
    def predict(self, max_num_steps, time_step, policy_state):
        maximum_iterations = math.ceil(max_num_steps / self._env.batch_size)
        [time_step, policy_state] = tf.while_loop(
            cond=lambda *_: True,
            body=self._eval_loop_body,
            loop_vars=[time_step, policy_state],
            maximum_iterations=maximum_iterations,
            back_prop=False,
            name="predict_loop")
        return time_step, policy_state
Example #7
    def _iter(self, time_step, policy_state):
        """One training iteration."""
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=self._train_interval,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(
            create_ta,
            self._training_info_spec._replace(
                info=nest_utils.to_distribution_param_spec(
                    self._training_info_spec.info)))

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            [counter, next_time_step, next_state, training_info_ta
             ] = tf.while_loop(cond=lambda *_: True,
                               body=self._train_loop_body,
                               loop_vars=[
                                   counter, time_step, policy_state,
                                   training_info_ta
                               ],
                               back_prop=True,
                               parallel_iterations=1,
                               maximum_iterations=self._train_interval,
                               name='iter_loop')

            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)

            training_info = nest_utils.params_to_distributions(
                training_info, self._training_info_spec)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        del tape

        self._algorithm.summarize_train(training_info, loss_info,
                                        grads_and_vars)
        self._algorithm.summarize_metrics()

        common.get_global_counter().assign_add(1)

        return [next_time_step, next_state]
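The noteworthy pattern here is running tf.while_loop under a tf.GradientTape so gradients flow back through every loop iteration. A minimal standalone illustration (hypothetical variable `w`, TF2 eager):

import tensorflow as tf

w = tf.Variable(2.0)
with tf.GradientTape() as tape:
    # Three iterations of acc *= w, so the loop computes w**3.
    _, y = tf.while_loop(lambda i, acc: i < 3,
                         lambda i, acc: (i + 1, acc * w),
                         (tf.constant(0), tf.constant(1.0)))
[grad] = tape.gradient(y, [w])  # d(w**3)/dw = 3 * w**2 = 12.0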
Example #8
    def fwd_fn(query_queries_fwd, query_values_fwd, support_keys_fwd,
               support_values_fwd, labels_fwd):
      """CrossTransformer forward, using a while loop to save memory."""
      initial = (0,
                 tf.zeros([tf.reduce_max(labels) + 1, zero_dim],
                          dtype=query_queries_fwd.dtype))

      def loop_body(idx, dist):
        dist_new = self._get_dist(query_queries_fwd[idx:idx + 1],
                                  query_values_fwd[idx:idx + 1],
                                  support_keys_fwd, support_values_fwd,
                                  labels_fwd)
        dist = tf.concat([dist, dist_new], axis=1)
        return (idx + 1, dist)

      _, res = tf.while_loop(
          lambda x, _: x < tf.shape(query_queries_fwd)[0],
          loop_body,
          initial,
          parallel_iterations=1)
      return res
Example #9
def optimizer_loop(
    num_updates,
    objective_fn,
    update_fn,
    variables,
    first_order,
    clip_grad_norm,
):
    """Optimization of `objective_fn` for `num_updates` of `variables`."""

    # Optimizer specifics.
    init, update, get_params = update_fn()

    def _body(step, preupdate_vars):
        """Optimization loop body."""
        updated_vars = optimizer_update(
            iterate_collection=preupdate_vars,
            iteration_idx=step,
            objective_fn=objective_fn,
            update_fn=update,
            get_params_fn=get_params,
            first_order=first_order,
            clip_grad_norm=clip_grad_norm,
        )

        return step + 1, updated_vars

    def _cond(step, *args):
        """Optimization truncation condition."""
        del args
        return step < num_updates

    step = tf.Variable(0, trainable=False, name='inner_step_counter')
    loop_vars = (step, [init(var) for var in variables])
    step, updated_vars = tf.while_loop(cond=_cond,
                                       body=_body,
                                       loop_vars=loop_vars,
                                       swap_memory=True)

    return [get_params(v) for v in updated_vars]
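Here `update_fn()` returns an (init, update, get_params) triple, but `optimizer_update` itself is opaque. A plain-SGD stand-in shows the same truncated while-loop optimization pattern in isolation (hypothetical, not this repo's optimizer):

import tensorflow as tf

def sgd_loop(num_updates, grads_fn, variables, lr=0.1):
    """Runs `num_updates` SGD steps inside a while loop."""
    def _body(step, v):
        return step + 1, [p - lr * g for p, g in zip(v, grads_fn(v))]

    _, out = tf.while_loop(lambda step, *_: step < num_updates, _body,
                           (tf.constant(0), variables), swap_memory=True)
    return out

# e.g. minimize (x - 3)**2 from x = 0:
[x] = sgd_loop(50, lambda v: [2.0 * (v[0] - 3.0)], [tf.constant(0.0)])
# x converges to 3.0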
Example #10
    def _iter(self, time_step, policy_state):
        """One training iteration."""
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=self._train_interval + 1,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(create_ta,
                                                 self._training_info_spec)

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            [counter, time_step, policy_state, training_info_ta
             ] = tf.while_loop(cond=lambda *_: True,
                               body=self._train_loop_body,
                               loop_vars=[
                                   counter, time_step, policy_state,
                                   training_info_ta
                               ],
                               back_prop=True,
                               parallel_iterations=1,
                               maximum_iterations=self._train_interval,
                               name='iter_loop')

        if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
            next_time_step, policy_step, action = self._step(
                time_step, policy_state)
            next_state = policy_step.state
        else:
            policy_step = common.algorithm_step(self._algorithm.rollout,
                                                self._observation_transformer,
                                                time_step, policy_state)
            action = common.sample_action_distribution(policy_step.action)
            next_time_step = time_step
            next_state = policy_state

        action_distribution_param = common.get_distribution_params(
            policy_step.action)

        final_training_info = make_training_info(
            action_distribution=action_distribution_param,
            action=action,
            reward=time_step.reward,
            discount=time_step.discount,
            step_type=time_step.step_type,
            info=policy_step.info)

        with tape:
            training_info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), training_info_ta,
                final_training_info)
            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)

            action_distribution = nested_distributions_from_specs(
                self._algorithm.action_distribution_spec,
                training_info.action_distribution)

            training_info = training_info._replace(
                action_distribution=action_distribution)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        del tape

        self._training_summary(training_info, loss_info, grads_and_vars)

        self._train_step_counter.assign_add(1)

        return next_time_step, next_state
Example #11
  def forward_pass(self, data):
    """Computes the test logits of MAML.

    Args:
      data: A `meta_dataset.providers.Episode` containing the data for the
        episode.

    Returns:
      The output logits for the query data in this episode.
    """
    # Have to use one-hot labels since sparse softmax doesn't allow
    # second derivatives.
    support_embeddings_ = self.embedding_fn(
        data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
    support_embeddings = support_embeddings_['embeddings']
    embedding_vars_dict = support_embeddings_['params']

    # TODO(eringrant): Refactor to make use of
    # `functional_backbones.linear_classifier`, which allows Gin-configuration.
    with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
      embedding_depth = support_embeddings.shape.as_list()[-1]
      fc_weights = functional_backbones.weight_variable(
          [embedding_depth, self.logit_dim],
          weight_decay=self.classifier_weight_decay)
      fc_bias = functional_backbones.bias_variable([self.logit_dim])

    # A list of variable names, a list of corresponding Variables, and a list
    # of operations (possibly empty) that creates a copy of each Variable.
    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         embedding_vars_dict, make_copies=not self.is_training)

    # A Variable for the weights of the fc layer, a Variable for the bias of the
    # fc layer, and a list of operations (possibly empty) that copies them.
    (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
        fc_weights, fc_bias, make_copies=not self.is_training)

    fc_vars = [fc_weights, fc_bias]
    num_embedding_vars = len(embedding_vars)
    num_fc_vars = len(fc_vars)

    def _cond(step, *args):
      del args
      num_steps = self.num_update_steps
      if not self.is_training:
        num_steps += self.additional_evaluation_update_steps
      return step < num_steps

    def _body(step, *args):
      """The inner update loop body."""
      updated_embedding_vars = args[0:num_embedding_vars]
      updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                             num_fc_vars]
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['embeddings']

      updated_fc_weights, updated_fc_bias = updated_fc_vars
      support_logits = tf.matmul(support_embeddings,
                                 updated_fc_weights) + updated_fc_bias

      support_logits = support_logits[:, 0:data.way]
      loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                             support_logits)

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

      with tf.control_dependencies([print_op]):
        updated_embedding_vars = gradient_descent_step(
            loss, updated_embedding_vars, self.first_order,
            self.adapt_batch_norm, self.alpha, False)['updated_vars']
        updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                                self.first_order,
                                                self.adapt_batch_norm,
                                                self.alpha,
                                                False)['updated_vars']

        step = step + 1
      return tuple([step] + list(updated_embedding_vars) +
                   list(updated_fc_vars))

    # MAML meta updates using query set examples from an episode.
    if self.zero_fc_layer:
      # To account for variable class sizes, we initialize the output
      # weights to zero. See if truncated normal initialization will help.
      zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
      zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
      fc_vars_init_ops = [zero_weights_op, zero_bias_op]
    else:
      fc_vars_init_ops = fc_vars_copy_ops

    if self.proto_maml_fc_layer_init:
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']

      prototypes = metric_learners.compute_prototypes(
          support_embeddings, data.onehot_support_labels)
      pmaml_fc_weights = self.proto_maml_fc_weights(
          prototypes, zero_pad_to_max_way=True)
      pmaml_fc_bias = self.proto_maml_fc_bias(
          prototypes, zero_pad_to_max_way=True)
      fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

    # These control dependencies assign the value of each variable to a new copy
    # variable that corresponds to it. This is required at test time for
    # initializing the copies, as they are used in place of the original vars.
    with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
      # Make step a local variable as we don't want to save and restore it.
      step = tf.Variable(
          0,
          trainable=False,
          name='inner_step_counter',
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      loop_vars = [step] + embedding_vars + fc_vars
      step_and_all_updated_vars = tf.while_loop(
          _cond, _body, loop_vars, swap_memory=True)
      step = step_and_all_updated_vars[0]
      all_updated_vars = step_and_all_updated_vars[1:]
      updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
      updated_fc_weights, updated_fc_bias = all_updated_vars[
          num_embedding_vars:num_embedding_vars + num_fc_vars]

    # Forward pass the training images with the updated weights in order to
    # compute the means and variances, to use for the query's batch norm.
    support_set_moments = None
    if not self.transductive_batch_norm:
      support_set_moments = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['moments']

    query_embeddings = self.embedding_fn(
        data.query_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        moments=support_set_moments,  # Use support set stats for batch norm.
        reuse=True,
        backprop_through_moments=self.backprop_through_moments)['embeddings']

    query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                    updated_fc_bias)[:, 0:data.way]

    return query_logits
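`gradient_descent_step` is used above as a black box returning a dict with an 'updated_vars' entry. A first-order sketch of such a step (hypothetical, TF1-style; the real version above also handles batch-norm adaptation and second-order gradients):

import tensorflow as tf

def inner_sgd_step(loss, variables, alpha):
  """One first-order MAML-style inner update."""
  grads = tf.gradients(loss, variables)
  # stop_gradient on the gradients gives the first-order approximation.
  return {'updated_vars': [v - alpha * tf.stop_gradient(g)
                           for v, g in zip(variables, grads)]}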