Ejemplo n.º 1
0
  def update_state(self, inputs, outputs):
    """Function that updates the metric state at each example.

    For every class in `self.class_range`, the detections of that class are
    sorted by descending score and greedily matched against the ground-truth
    instances of the same class: a detection is a true positive when the mask
    IoU with its best-overlapping ground-truth instance exceeds
    `self.iou_threshold` and that instance has not already been claimed by a
    higher-scoring detection. The per-class accumulators `self.tp`,
    `self.scores` and `self.num_gt` are extended with the results.

    Ground-truth objects are counted for every class, including classes that
    received no detection in this example, so that missed objects are not
    dropped from the ground-truth total.

    Args:
      inputs: A dictionary containing input tensors. The `objects_class` and
        `object_instance_id_voxels` entries of `InputDataFields` are read.
      outputs: A dictionary containing output tensors. The `objects_score`,
        `objects_class` and `instance_segments_voxel_mask` entries of
        `DetectionResultFields` are read.

    Returns:
      Update op.
    """
    detections_score = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_score], [-1])
    detections_class = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_class], [-1])
    num_detections = tf.shape(detections_score)[0]
    # One flattened mask row per detection.
    detections_instance_mask = tf.reshape(
        outputs[
            standard_fields.DetectionResultFields.instance_segments_voxel_mask],
        [num_detections, -1])
    gt_class = tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                          [-1])
    num_gt = tf.shape(gt_class)[0]
    gt_voxel_instance_ids = tf.reshape(
        inputs[standard_fields.InputDataFields.object_instance_id_voxels], [-1])
    # Ids are shifted down by one before one-hot encoding, so an id of 0 maps
    # to index -1 and produces an all-zero column (presumably voxels that
    # belong to no object -- TODO confirm). After the transpose there is one
    # mask row per ground-truth instance.
    gt_instance_masks = tf.transpose(
        tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
    for c in self.class_range:
      gt_mask_c = tf.equal(gt_class, c)
      num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
      gt_instance_masks_c = tf.boolean_mask(gt_instance_masks, gt_mask_c)
      detections_mask_c = tf.equal(detections_class, c)
      num_detections_c = tf.math.reduce_sum(
          tf.cast(detections_mask_c, dtype=tf.int32))
      if num_detections_c == 0:
        # Nothing to append to tp/scores, but the ground-truth objects of this
        # class were still present (and missed), so they must be counted.
        self.num_gt[c] += num_gt_c
        continue
      det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
      det_instance_mask_c = tf.boolean_mask(detections_instance_mask,
                                            detections_mask_c)
      # top_k with k equal to the number of detections sorts all of them by
      # descending score.
      det_scores_c, sorted_indices = tf.math.top_k(
          det_scores_c, k=num_detections_c)
      det_instance_mask_c = tf.gather(det_instance_mask_c, sorted_indices)
      tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
      if num_gt_c > 0:
        ious_c = instance_segmentation_utils.points_mask_iou(
            masks1=gt_instance_masks_c, masks2=det_instance_mask_c)
        # For each detection, the index of the ground-truth instance it
        # overlaps the most.
        max_overlap_gt_ids = tf.cast(
            tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
        is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
        for i in tf.range(num_detections_c):
          gt_id = max_overlap_gt_ids[i]
          if (ious_c[gt_id, i] > self.iou_threshold and
              is_gt_box_detected[gt_id] == 0):
            # Flag detection i as a true positive and mark its ground-truth
            # instance as taken so later detections cannot match it again.
            tp_c = tf.maximum(
                tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
            is_gt_box_detected = tf.maximum(
                tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
      self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
      self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
      self.num_gt[c] += num_gt_c
    return tf.no_op()
Ejemplo n.º 2
0
    def _body(step, *args):
      """Runs one inner-loop gradient step on the support set.

      Args:
        step: The current step counter.
        *args: The current embedding variable values followed by the current
          fc-layer values (weights, then bias), flattened into one sequence.

      Returns:
        A tuple holding the incremented step followed by the updated embedding
        values and the updated fc-layer values, in the same layout as the
        inputs.
      """
      fc_end = num_embedding_vars + num_fc_vars
      embedding_values = args[:num_embedding_vars]
      fc_values = args[num_embedding_vars:fc_end]

      # Embed the support images using the current (adapted) parameters.
      embedding_params = collections.OrderedDict(
          zip(embedding_vars_keys, embedding_values))
      embedding_outputs = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=embedding_params,
          reuse=True)
      support_embeddings = embedding_outputs['embeddings']

      fc_weights, fc_bias = fc_values
      full_logits = tf.matmul(support_embeddings, fc_weights) + fc_bias

      # Only the first `way` outputs are used for this episode.
      way_logits = full_logits[:, 0:data.way]
      loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                             way_logits)

      log_op = tf.no_op()
      if self.debug_log:
        log_op = tf.print(['step: ', step, fc_bias[0], 'loss:', loss])

      # The updates (and the step increment) are created under the logging op
      # so that it runs whenever the loop body does.
      with tf.control_dependencies([log_op]):
        embedding_update = gradient_descent_step(
            loss, embedding_values, self.first_order, self.adapt_batch_norm,
            self.alpha, False)
        embedding_values = embedding_update['updated_vars']
        fc_update = gradient_descent_step(
            loss, fc_values, self.first_order, self.adapt_batch_norm,
            self.alpha, False)
        fc_values = fc_update['updated_vars']

        step = step + 1
      return (step,) + tuple(embedding_values) + tuple(fc_values)
Ejemplo n.º 3
0
    def __init__(self,
                 num_actions,
                 observation_size,
                 stack_size,
                 use_staging=True,
                 replay_capacity=1000000,
                 batch_size=32,
                 update_horizon=1,
                 gamma=1.0,
                 wrapped_memory=None):
        """Initializes a graph wrapper for the python replay memory.

        Builds the placeholders and `tf.py_func` ops used to add transitions
        to, and sample batches from, the underlying python memory, optionally
        routing sampled batches through a staging area.

        Args:
          num_actions: int, number of possible actions.
          observation_size: int, size of an input frame.
          stack_size: int, number of frames to use in state stack.
          use_staging: bool, when True it would use a staging area to prefetch
            the next sampling batch.
          replay_capacity: int, number of transitions to keep in memory.
          batch_size: int.
          update_horizon: int, length of update ('n' in n-step update).
          gamma: float, the discount factor; must lie in [0, 1].
          wrapped_memory: The 'inner' memory data structure. Defaults to None,
            which creates the standard DQN replay memory.

        Raises:
          ValueError: If replay_capacity is smaller than update_horizon + 1.
          ValueError: If update_horizon is not positive.
          ValueError: If discount factor is not in [0, 1].
        """
        if replay_capacity < update_horizon + 1:
            raise ValueError(
                'Update horizon (%i) should be significantly smaller '
                'than replay capacity (%i).' %
                (update_horizon, replay_capacity))
        if not update_horizon >= 1:
            raise ValueError('Update horizon must be positive.')
        if not 0.0 <= gamma <= 1.0:
            raise ValueError('Discount factor (gamma) must be in [0, 1].')

        # Allow subclasses to create self.memory.
        if wrapped_memory is not None:
            self.memory = wrapped_memory
        else:
            self.memory = OutOfGraphReplayMemory(num_actions, observation_size,
                                                 stack_size, replay_capacity,
                                                 batch_size, update_horizon,
                                                 gamma)

        with tf.name_scope('replay'):
            # Placeholders feeding a single transition into `memory.add`.
            with tf.name_scope('add_placeholders'):
                self.add_obs_ph = tf.placeholder(tf.uint8, [observation_size],
                                                 name='add_obs_ph')
                self.add_action_ph = tf.placeholder(tf.int32, [],
                                                    name='add_action_ph')
                self.add_reward_ph = tf.placeholder(tf.float32, [],
                                                    name='add_reward_ph')
                self.add_terminal_ph = tf.placeholder(tf.uint8, [],
                                                      name='add_terminal_ph')
                self.add_legal_actions_ph = tf.placeholder(
                    tf.float32, [num_actions], name='add_legal_actions_ph')

            # Order matters: these are passed positionally to `memory.add`.
            add_transition_ph = [
                self.add_obs_ph, self.add_action_ph, self.add_reward_ph,
                self.add_terminal_ph, self.add_legal_actions_ph
            ]

            with tf.device('/cpu:*'):
                # `memory.add` is wrapped with an empty output-type list.
                self.add_transition_op = tf.py_func(self.memory.add,
                                                    add_transition_ph, [],
                                                    name='replay_add_py_func')

                # Output dtypes, in order: states, actions, rewards,
                # next_states, terminals, indices, next_legal_actions.
                self.transition = tf.py_func(
                    self.memory.sample_transition_batch, [], [
                        tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8,
                        tf.int32, tf.float32
                    ],
                    name='replay_sample_py_func')

                if use_staging:
                    # To hide the py_func latency use a staging area to pre-fetch the next
                    # batch of transitions.
                    (states, actions, rewards, next_states, terminals, indices,
                     next_legal_actions) = self.transition
                    # StagingArea requires all the shapes to be defined.
                    states.set_shape(
                        [batch_size, observation_size, stack_size])
                    actions.set_shape([batch_size])
                    rewards.set_shape([batch_size])
                    next_states.set_shape(
                        [batch_size, observation_size, stack_size])
                    terminals.set_shape([batch_size])
                    indices.set_shape([batch_size])
                    next_legal_actions.set_shape([batch_size, num_actions])

                    # Create the staging area in CPU.
                    prefetch_area = tf.contrib.staging.StagingArea([
                        tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8,
                        tf.int32, tf.float32
                    ])

                    self.prefetch_batch = prefetch_area.put(
                        (states, actions, rewards, next_states, terminals,
                         indices, next_legal_actions))
                else:
                    # Without staging there is nothing to prefetch.
                    self.prefetch_batch = tf.no_op()

            if use_staging:
                # Get the sample_transition_batch in GPU. This would do the copy from
                # CPU to GPU.
                self.transition = prefetch_area.get()

            (self.states, self.actions, self.rewards, self.next_states,
             self.terminals, self.indices,
             self.next_legal_actions) = self.transition

            # Since these are py_func tensors, no information about their shape is
            # present. Setting the shape only for the necessary tensors
            self.states.set_shape([None, observation_size, stack_size])
            self.next_states.set_shape([None, observation_size, stack_size])
Ejemplo n.º 4
0
    def update_state(self, inputs, outputs):
        """Function that updates the metric state at each example.

        For every class in `self.class_range`, the detected boxes of that class
        are sorted by descending score and greedily matched against the
        ground-truth boxes of the same class: a detection is a true positive
        when the 3D IoU with its best-overlapping ground-truth box exceeds
        `self.iou_threshold` and that box has not already been claimed by a
        higher-scoring detection. The per-class accumulators `self.tp`,
        `self.scores` and `self.num_gt` are extended with the results.

        Ground-truth objects are counted for every class, including classes
        that received no detection in this example, so that missed objects are
        not dropped from the ground-truth total.

        Args:
          inputs: A dictionary containing input tensors. The `objects_class`,
            `objects_length`, `objects_height`, `objects_width`,
            `objects_center` and `objects_rotation_matrix` entries of
            `InputDataFields` are read.
          outputs: A dictionary containing output tensors. The `objects_score`
            entry and the same box entries of `DetectionResultFields` are
            read.

        Returns:
          Update op.
        """
        detections_score = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_score], [-1])
        detections_class = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_class], [-1])
        detections_length = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_length],
            [-1])
        detections_height = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_height],
            [-1])
        detections_width = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_width], [-1])
        detections_center = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_center],
            [-1, 3])
        detections_rotation_matrix = tf.reshape(
            outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            [-1, 3, 3])
        gt_class = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_class], [-1])
        gt_length = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_length], [-1])
        gt_height = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_height], [-1])
        gt_width = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_width], [-1])
        gt_center = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_center], [-1, 3])
        gt_rotation_matrix = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_rotation_matrix],
            [-1, 3, 3])
        for c in self.class_range:
            gt_mask_c = tf.equal(gt_class, c)
            num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
            gt_length_c = tf.boolean_mask(gt_length, gt_mask_c)
            gt_height_c = tf.boolean_mask(gt_height, gt_mask_c)
            gt_width_c = tf.boolean_mask(gt_width, gt_mask_c)
            gt_center_c = tf.boolean_mask(gt_center, gt_mask_c)
            gt_rotation_matrix_c = tf.boolean_mask(gt_rotation_matrix,
                                                   gt_mask_c)
            detections_mask_c = tf.equal(detections_class, c)
            num_detections_c = tf.math.reduce_sum(
                tf.cast(detections_mask_c, dtype=tf.int32))
            if num_detections_c == 0:
                # Nothing to append to tp/scores, but the ground-truth objects
                # of this class were still present (and missed), so they must
                # be counted.
                self.num_gt[c] += num_gt_c
                continue
            det_length_c = tf.boolean_mask(detections_length,
                                           detections_mask_c)
            det_height_c = tf.boolean_mask(detections_height,
                                           detections_mask_c)
            det_width_c = tf.boolean_mask(detections_width, detections_mask_c)
            det_center_c = tf.boolean_mask(detections_center,
                                           detections_mask_c)
            det_rotation_matrix_c = tf.boolean_mask(detections_rotation_matrix,
                                                    detections_mask_c)
            det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
            # top_k with k equal to the number of detections sorts all of them
            # by descending score; the box attributes are reordered to match.
            det_scores_c, sorted_indices = tf.math.top_k(det_scores_c,
                                                         k=num_detections_c)
            det_length_c = tf.gather(det_length_c, sorted_indices)
            det_height_c = tf.gather(det_height_c, sorted_indices)
            det_width_c = tf.gather(det_width_c, sorted_indices)
            det_center_c = tf.gather(det_center_c, sorted_indices)
            det_rotation_matrix_c = tf.gather(det_rotation_matrix_c,
                                              sorted_indices)
            tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
            if num_gt_c > 0:
                ious_c = box_ops.iou3d(
                    boxes1_length=gt_length_c,
                    boxes1_height=gt_height_c,
                    boxes1_width=gt_width_c,
                    boxes1_center=gt_center_c,
                    boxes1_rotation_matrix=gt_rotation_matrix_c,
                    boxes2_length=det_length_c,
                    boxes2_height=det_height_c,
                    boxes2_width=det_width_c,
                    boxes2_center=det_center_c,
                    boxes2_rotation_matrix=det_rotation_matrix_c)
                # For each detection, the index of the ground-truth box it
                # overlaps the most.
                max_overlap_gt_ids = tf.cast(tf.math.argmax(ious_c, axis=0),
                                             dtype=tf.int32)
                is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
                for i in tf.range(num_detections_c):
                    gt_id = max_overlap_gt_ids[i]
                    if (ious_c[gt_id, i] > self.iou_threshold
                            and is_gt_box_detected[gt_id] == 0):
                        # Flag detection i as a true positive and mark its
                        # ground-truth box as taken so later detections cannot
                        # match it again.
                        tp_c = tf.maximum(
                            tf.one_hot(i, num_detections_c, dtype=tf.int32),
                            tp_c)
                        is_gt_box_detected = tf.maximum(
                            tf.one_hot(gt_id, num_gt_c, dtype=tf.int32),
                            is_gt_box_detected)
            self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
            self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
            self.num_gt[c] += num_gt_c
        return tf.no_op()
Ejemplo n.º 5
0
  def compute_logits(self, data):
    """Computes the class logits for the episode.

    Builds `self.num_finetune_steps` chained finetuning ops on the support set.
    Before the first step the fc layer is re-initialized, as is the optimizer
    state when `self.finetune_with_adam` is set. The query logits are then
    computed under a control dependency on the last finetuning op. When no
    finetuning step is requested, the query logits are computed directly with
    the current variables.

    Args:
      data: A `meta_dataset.providers.Episode`.

    Returns:
      The query set logits as a [num_query_images, way] matrix.

    Raises:
      ValueError: Distance must be one of l2 or cosine.
    """
    # ------------------------ Finetuning -------------------------------
    # Possibly make copies of embedding variables, if they will get modified.
    # This is for making temporary-only updates to the embedding network
    # which will not persist after the end of the episode.
    make_copies = self.finetune_all_layers

    # TODO(eringrant): Reduce the number of times the embedding function graph
    # is built with the same input.
    support_embeddings_params_moments = self.embedding_fn(
        data.support_images, self.is_training)
    support_embeddings = support_embeddings_params_moments['embeddings']
    support_embeddings_var_dict = support_embeddings_params_moments['params']

    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         support_embeddings_var_dict, make_copies)
    embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

    # Compute the initial training loss (only for printing purposes). This
    # line is also needed for adding the fc variables to the graph so that the
    # tf.all_variables() line below detects them.
    logits = self._fc_layer(support_embeddings)[:, 0:data.way]
    finetune_loss = self.compute_loss(
        onehot_labels=data.onehot_support_labels,
        predictions=logits,
    )

    # Decide which variables to finetune.
    fc_vars, vars_to_finetune = [], []
    for var in tf.trainable_variables():
      if 'fc_finetune' in var.name:
        fc_vars.append(var)
        vars_to_finetune.append(var)
    if self.finetune_all_layers:
      vars_to_finetune.extend(embedding_vars)
    logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

    # Stays None when the loop below builds no finetuning step, so the final
    # query computation knows there is no op to wait on.
    finetune_op = None
    for i in range(self.num_finetune_steps):
      if i == 0:
        # Randomly initialize the fc layer.
        fc_reset = tf.variables_initializer(var_list=fc_vars)
        # Adam related variables are created when minimize() is called.
        # We create an unused op here to put all adam variables under
        # the 'adam_opt' namescope and create a reset op to reinitialize
        # these variables before the first finetune step.
        adam_reset = tf.no_op()
        if self.finetune_with_adam:
          with tf.variable_scope('adam_opt'):
            unused_op = self.finetune_opt.minimize(
                finetune_loss, var_list=vars_to_finetune)
          adam_reset = tf.variables_initializer(self.finetune_opt.variables())
        with tf.control_dependencies(
            [fc_reset, adam_reset, finetune_loss, embedding_vars_copy_op] +
            vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
                finetune_loss
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

      else:
        # Chain this step after the previous one so the steps run in order.
        with tf.control_dependencies([finetune_op, finetune_loss] +
                                     vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i,
                vars_to_finetune[0][0, 0],
                'loss:',
                finetune_loss,
                'accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_support_labels, predictions=logits),
                'query accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_query_labels, predictions=query_logits),
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

    # Finetuning is now over, compute the query performance using the updated
    # fc layer, and possibly the updated embedding network. If no finetuning
    # step was built there is nothing to depend on.
    if finetune_op is None:
      final_dependencies = []
    else:
      final_dependencies = [finetune_op] + vars_to_finetune
    with tf.control_dependencies(final_dependencies):
      query_embeddings = self.embedding_fn(
          data.query_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']
      query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

      if self.debug_log:
        # The train logits are computed only for printing.
        support_embeddings = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, embedding_vars)),
            reuse=True)['embeddings']
        logits = self._fc_layer(support_embeddings)[:, 0:data.way]

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print([
            'accuracy:',
            self.compute_accuracy(
                labels=data.onehot_support_labels, predictions=logits),
            'query accuracy:',
            self.compute_accuracy(
                labels=data.onehot_query_labels, predictions=query_logits),
        ])
      # Rebuild the returned logits under the print op so that fetching them
      # also triggers the debug print.
      with tf.control_dependencies([print_op]):
        query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

    return query_logits
Ejemplo n.º 6
0
 def no_op_initialization(onehot_labels, embeddings, *vbls):
     """Ignores every argument and returns a one-element list with a no-op."""
     del onehot_labels, embeddings, vbls  # Unused.
     return [tf.no_op()]
Ejemplo n.º 7
0
    def forward_pass(self, data):
        """Computes the query logits for the given episode `data`.

        First sets `self.film_selector` according to `self.film_init`, then
        builds `self.num_steps` chained training ops over the variables whose
        name contains '_for_film_learner', and finally computes the query
        logits under a control dependency on the last training op.

        Args:
          data: The episode; its `support_images` are read here and the whole
            object is forwarded to `_compute_losses` and `_get_train_op`.

        Returns:
          The query logits, passed through `tf.identity` under the final
          (possibly no-op) print op.
        """

        # NOTE(review): any `film_init` value other than the ones handled
        # below leaves `self.film_selector` untouched -- confirm this is
        # intended.
        if self.film_init == 'scratch':
            self.film_selector = None
        elif self.film_init == 'imagenet':
            # Note: this makes the assumption that the first set of learned FiLM
            # parameters corresponds to the ImageNet dataset. Otherwise, the
            # following line should be changed appropriately.
            self.film_selector = 0
        elif self.film_init in ['blender', 'blender_hard']:
            dataset_logits = functional_backbones.dataset_classifier(
                data.support_images)
            if self.film_init == 'blender_hard':
                # Select only the argmax entry.
                self.film_selector = tf.one_hot(
                    tf.math.argmax(dataset_logits, axis=-1),
                    depth=tf.shape(dataset_logits)[1])
            else:
                # Take a convex combination.
                self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

        if self.num_steps:
            # Initial forward pass, required for the `unused_op` below and for placing
            # variables in tf.trainable_variables() for the below block to pick up.
            loss = self._compute_losses(data, compute_on_query=False)['loss']

            # Pick out the variables to optimize.
            self.opt_vars = []
            for var in tf.trainable_variables():
                if '_for_film_learner' in var.name:
                    self.opt_vars.append(var)
            tf.logging.info('FiLMLearner will optimize vars: {}'.format(
                self.opt_vars))

        for i in range(self.num_steps):
            if i == 0:
                # Re-initialize the variables to optimize for the new episode, to ensure
                # the FiLM parameters aren't re-used across tasks of a given dataset.
                vars_reset = tf.variables_initializer(var_list=self.opt_vars)
                # Adam related variables are created when minimize() is called.
                # We create an unused op here to put all adam variables under
                # the 'adam_opt' namescope and create a reset op to reinitialize
                # these variables before the first finetune step.
                with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                    unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
                adam_reset = tf.variables_initializer(self.opt.variables())

                # The first step is built under the two reset ops.
                with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                             self.opt_vars):
                    print_op = tf.no_op()
                    if self.debug_log:
                        print_op = tf.print([
                            'step: %d' % i, self.opt_vars[0][0], 'loss:', loss
                        ],
                                            summarize=-1)

                    with tf.control_dependencies([print_op]):
                        # Get the train op.
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

            else:
                # Chain this step after the previous one. Multiplying the list
                # by `int(self.debug_log)` includes the query tensors in the
                # dependencies only when `debug_log` is truthy (assuming it is
                # a bool).
                with tf.control_dependencies([train_op, loss, acc] +
                                             self.opt_vars +
                                             [query_loss, query_acc] *
                                             int(self.debug_log)):

                    print_op = tf.no_op()
                    if self.debug_log:
                        print_list = [
                            '################',
                            'step: %d' % i,
                            self.opt_vars[0][0],
                            'support loss:',
                            loss,
                            'query loss:',
                            query_loss,
                            'support acc:',
                            acc,
                            'query acc:',
                            query_acc,
                        ]
                        print_op = tf.print(print_list)

                    with tf.control_dependencies([print_op]):
                        # Get the train op (the loss is returned just for printing).
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

        # Training is now over, compute the final query logits.
        # With zero steps neither `train_op` nor `self.opt_vars` is set in this
        # method, so the dependency list is left empty.
        dependency_list = [] if not self.num_steps else [train_op
                                                         ] + self.opt_vars
        with tf.control_dependencies(dependency_list):
            results = self._compute_losses(data, compute_on_query=True)
            (loss, query_loss, query_logits, acc,
             query_acc) = (results['loss'], results['query_loss'],
                           results['query_logits'], results['acc'],
                           results['query_acc'])

            print_op = tf.no_op()
            if self.debug_log:
                print_op = tf.print([
                    'Done training',
                    'support loss:',
                    loss,
                    'query loss:',
                    query_loss,
                    'support acc:',
                    acc,
                    'query acc:',
                    query_acc,
                ])
            # Wrapping the logits in an identity op created under the print op
            # makes the print a dependency of the returned tensor.
            with tf.control_dependencies([print_op]):
                query_logits = tf.identity(query_logits)

        return query_logits