Esempio n. 1
0
    def _body(step, *args):
      """Runs one inner-loop adaptation step.

      Args:
        step: The iteration counter of the inner loop.
        *args: The current embedding variables (first `num_embedding_vars`
          entries) followed by the `num_fc_vars` fc variables, which are
          unpacked as (weights, bias).

      Returns:
        A flat tuple `(step + 1, *new_embedding_vars, *new_fc_vars)` laid out
        like the inputs, so it can be fed back into the next iteration.
      """
      fc_end = num_embedding_vars + num_fc_vars
      embedding_params = args[0:num_embedding_vars]
      fc_params = args[num_embedding_vars:fc_end]

      # Re-embed the support set with the current (adapted) embedding params.
      embedding_param_dict = collections.OrderedDict(
          zip(embedding_vars_keys, embedding_params))
      support_features = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=embedding_param_dict,
          reuse=True)['embeddings']

      fc_weights, fc_bias = fc_params
      # Only the first `data.way` logit columns belong to this episode.
      logits = (tf.matmul(support_features, fc_weights) + fc_bias)[:, 0:data.way]
      loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                             logits)

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print(['step: ', step, fc_bias[0], 'loss:', loss])

      # The updates (and the counter increment) wait on the optional print.
      with tf.control_dependencies([print_op]):
        new_embedding_params = gradient_descent_step(
            loss, embedding_params, self.first_order, self.adapt_batch_norm,
            self.alpha, False)['updated_vars']
        new_fc_params = gradient_descent_step(
            loss, fc_params, self.first_order, self.adapt_batch_norm,
            self.alpha, False)['updated_vars']
        next_step = step + 1
      return (next_step,) + tuple(new_embedding_params) + tuple(new_fc_params)
Esempio n. 2
0
    def forward_pass(self, data):
        """Computes the query logits for the given episode `data`.

        First selects the FiLM parameters according to `self.film_init`. Then,
        if `self.num_steps` is non-zero, builds `self.num_steps` chained
        optimization steps on the trainable variables whose names contain
        '_for_film_learner'. Before the first step it re-initializes those
        variables and the optimizer's variables. Finally it computes the query
        logits after the last step.

        Side effects: sets `self.film_selector` for the recognized `film_init`
        values, and sets `self.opt_vars` when `self.num_steps` is non-zero.

        Args:
          data: The episode. This method reads `data.support_images` directly
            and passes `data` on to `self._compute_losses` and
            `self._get_train_op`.

        Returns:
          The query logits from `self._compute_losses(data,
          compute_on_query=True)`, passed through `tf.identity` so that they
          carry a control dependency on the (optional) debug print op.
        """

        # Decide which FiLM parameters to use for this episode.
        if self.film_init == 'scratch':
            # No pre-learned FiLM parameters are selected.
            self.film_selector = None
        elif self.film_init == 'imagenet':
            # Note: this makes the assumption that the first set of learned FiLM
            # parameters corresponds to the ImageNet dataset. Otherwise, the
            # following line should be changed appropriately.
            self.film_selector = 0
        elif self.film_init in ['blender', 'blender_hard']:
            # Score the support images with the dataset classifier to decide how
            # to weight the per-dataset FiLM parameters.
            dataset_logits = functional_backbones.dataset_classifier(
                data.support_images)
            if self.film_init == 'blender_hard':
                # Select only the argmax entry.
                self.film_selector = tf.one_hot(
                    tf.math.argmax(dataset_logits, axis=-1),
                    depth=tf.shape(dataset_logits)[1])
            else:
                # Take a convex combination.
                self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)
        # NOTE(review): there is no `else` branch, so any other `film_init`
        # value leaves `self.film_selector` unchanged. Presumably the value is
        # validated elsewhere; confirm.

        if self.num_steps:
            # Initial forward pass, required for the `unused_op` below and for placing
            # variables in tf.trainable_variables() for the below block to pick up.
            loss = self._compute_losses(data, compute_on_query=False)['loss']

            # Pick out the variables to optimize.
            self.opt_vars = []
            for var in tf.trainable_variables():
                if '_for_film_learner' in var.name:
                    self.opt_vars.append(var)
            tf.logging.info('FiLMLearner will optimize vars: {}'.format(
                self.opt_vars))

        for i in range(self.num_steps):
            if i == 0:
                # Re-initialize the variables to optimize for the new episode, to ensure
                # the FiLM parameters aren't re-used across tasks of a given dataset.
                vars_reset = tf.variables_initializer(var_list=self.opt_vars)
                # Adam related variables are created when minimize() is called.
                # We create an unused op here to put all adam variables under
                # the 'adam_opt' namescope and create a reset op to reinitialize
                # these variables before the first finetune step.
                with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                    unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
                adam_reset = tf.variables_initializer(self.opt.variables())

                # The first train step waits on both resets, the initial loss
                # and the optimized variables.
                with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                             self.opt_vars):
                    print_op = tf.no_op()
                    if self.debug_log:
                        print_op = tf.print([
                            'step: %d' % i, self.opt_vars[0][0], 'loss:', loss
                        ],
                                            summarize=-1)

                    with tf.control_dependencies([print_op]):
                        # Get the train op.
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

            else:
                # Each later step depends on the previous step's train op, so the
                # steps are chained in order. The query metrics are added to the
                # dependencies only when debug logging is on (the list is
                # multiplied by 0 or 1).
                with tf.control_dependencies([train_op, loss, acc] +
                                             self.opt_vars +
                                             [query_loss, query_acc] *
                                             int(self.debug_log)):

                    print_op = tf.no_op()
                    if self.debug_log:
                        print_list = [
                            '################',
                            'step: %d' % i,
                            self.opt_vars[0][0],
                            'support loss:',
                            loss,
                            'query loss:',
                            query_loss,
                            'support acc:',
                            acc,
                            'query acc:',
                            query_acc,
                        ]
                        print_op = tf.print(print_list)

                    with tf.control_dependencies([print_op]):
                        # Get the train op (the loss is returned just for printing).
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

        # Training is now over, compute the final query logits.
        # With no adaptation steps there is no `train_op` to wait on.
        dependency_list = [] if not self.num_steps else [train_op
                                                         ] + self.opt_vars
        with tf.control_dependencies(dependency_list):
            results = self._compute_losses(data, compute_on_query=True)
            (loss, query_loss, query_logits, acc,
             query_acc) = (results['loss'], results['query_loss'],
                           results['query_logits'], results['acc'],
                           results['query_acc'])

            print_op = tf.no_op()
            if self.debug_log:
                print_op = tf.print([
                    'Done training',
                    'support loss:',
                    loss,
                    'query loss:',
                    query_loss,
                    'support acc:',
                    acc,
                    'query acc:',
                    query_acc,
                ])
            # Route the logits through an identity op so evaluating them also
            # triggers the debug print.
            with tf.control_dependencies([print_op]):
                query_logits = tf.identity(query_logits)

        return query_logits
Esempio n. 3
0
  def compute_logits(self, data):
    """Computes the class logits for the episode.

    A fully-connected (fc) head is finetuned on the support set for
    `self.num_finetune_steps` steps. When `self.finetune_all_layers` is set,
    the embedding network is finetuned as well. The finetuned model is then
    applied to the query set. The fc variables are re-initialized at the start
    of every episode. With `self.finetune_with_adam`, the Adam variables are
    re-initialized too.

    Args:
      data: A `meta_dataset.providers.Episode`.

    Returns:
      The query set logits as a [num_query_images, way] matrix.
    """
    # ------------------------ Finetuning -------------------------------
    # Possibly make copies of embedding variables, if they will get modified.
    # This is for making temporary-only updates to the embedding network
    # which will not persist after the end of the episode.
    make_copies = self.finetune_all_layers

    # TODO(eringrant): Reduce the number of times the embedding function graph
    # is built with the same input.
    support_embeddings_params_moments = self.embedding_fn(
        data.support_images, self.is_training)
    support_embeddings = support_embeddings_params_moments['embeddings']
    support_embeddings_var_dict = support_embeddings_params_moments['params']

    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         support_embeddings_var_dict, make_copies)
    embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

    def embed(images):
      """Embeds `images` with the current (possibly copied) embedding vars."""
      return self.embedding_fn(
          images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']

    # Compute the initial training loss (only for printing purposes). This
    # line is also needed for adding the fc variables to the graph so that the
    # tf.trainable_variables() call below detects them.
    logits = self._fc_layer(support_embeddings)[:, 0:data.way]
    finetune_loss = self.compute_loss(
        onehot_labels=data.onehot_support_labels,
        predictions=logits,
    )

    # Decide which variables to finetune: always the fc layer, plus the
    # embedding variables when finetuning all layers.
    fc_vars = [
        var for var in tf.trainable_variables() if 'fc_finetune' in var.name
    ]
    vars_to_finetune = list(fc_vars)
    if self.finetune_all_layers:
      vars_to_finetune.extend(embedding_vars)
    logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

    def build_finetune_step():
      """Builds one finetuning update under the active control dependencies.

      Returns:
        A tuple (support logits, support loss, finetune op, query logits).
        The logits and loss are returned just for printing. The query logits
        are built only when `self.debug_log` is set (otherwise None) and are
        used only for logging.
      """
      step_logits, step_loss, step_op = self._get_finetune_op(
          data, embedding_vars_keys, embedding_vars, vars_to_finetune,
          support_embeddings if not self.finetune_all_layers else None)
      step_query_logits = None
      if self.debug_log:
        # Test logits are computed only for printing logs.
        step_query_logits = self._fc_layer(
            embed(data.query_images))[:, 0:data.way]
      return step_logits, step_loss, step_op, step_query_logits

    query_logits = None
    for i in range(self.num_finetune_steps):
      if i == 0:
        # Randomly initialize the fc layer.
        fc_reset = tf.variables_initializer(var_list=fc_vars)
        # Adam related variables are created when minimize() is called.
        # We build a minimize op here (its result is unused) so that all Adam
        # variables are created under the 'adam_opt' scope, then create a
        # reset op to reinitialize them before the first finetune step.
        adam_reset = tf.no_op()
        if self.finetune_with_adam:
          with tf.variable_scope('adam_opt'):
            self.finetune_opt.minimize(
                finetune_loss, var_list=vars_to_finetune)
          adam_reset = tf.variables_initializer(self.finetune_opt.variables())
        with tf.control_dependencies(
            [fc_reset, adam_reset, finetune_loss, embedding_vars_copy_op] +
            vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
                finetune_loss
            ])

          with tf.control_dependencies([print_op]):
            logits, finetune_loss, finetune_op, query_logits = (
                build_finetune_step())

      else:
        # Chain each step on the previous one so the updates run in order.
        with tf.control_dependencies([finetune_op, finetune_loss] +
                                     vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i,
                vars_to_finetune[0][0, 0],
                'loss:',
                finetune_loss,
                'accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_support_labels, predictions=logits),
                'query accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_query_labels, predictions=query_logits),
            ])

          with tf.control_dependencies([print_op]):
            logits, finetune_loss, finetune_op, query_logits = (
                build_finetune_step())

    # Finetuning is now over, compute the query performance using the updated
    # fc layer, and possibly the updated embedding network.
    if self.num_finetune_steps:
      final_dependencies = [finetune_op] + vars_to_finetune
    else:
      # No finetuning step was built, so there is no `finetune_op` to wait on.
      # Still run the fc reset and the embedding copy op, as the first
      # finetuning step would have, so the query logits don't depend on state
      # left over from a previous episode.
      final_dependencies = [
          tf.variables_initializer(var_list=fc_vars), embedding_vars_copy_op
      ]

    with tf.control_dependencies(final_dependencies):
      query_embeddings = embed(data.query_images)
      query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

      print_op = tf.no_op()
      if self.debug_log:
        # The train logits are computed only for printing.
        logits = self._fc_layer(embed(data.support_images))[:, 0:data.way]
        print_op = tf.print([
            'accuracy:',
            self.compute_accuracy(
                labels=data.onehot_support_labels, predictions=logits),
            'query accuracy:',
            self.compute_accuracy(
                labels=data.onehot_query_labels, predictions=query_logits),
        ])
      # Rebuild the logits under the print op so evaluating them also runs it.
      with tf.control_dependencies([print_op]):
        query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

    return query_logits
  def call(self, inputs, training=True):
    """Runs the encoder / bottleneck / decoder voxel network and task heads.

    Args:
      inputs: A tuple (voxel_features, voxel_xyz_indices, num_valid_voxels).
      training: Training flag forwarded to every sub-block. When True, the
        number of valid voxels is also printed with `tf.print`.

    Returns:
      A dict mapping each name in `self.task_names_to_num_output_channels` to
      the output of that task's second final conv block.
    """
    features, xyz_indices, valid_counts = inputs
    # Per-resolution stacks. Index 0 is the input resolution, and each encoder
    # level appends one pooled entry.
    feature_levels = [features]
    index_levels = [xyz_indices]
    count_levels = [valid_counts]
    pooling_maps = []

    if training:
      tf.print('num_valid_voxels', valid_counts)

    # Encoder: convolve at the current resolution, then pool by (2, 2, 2).
    for level in range(self.num_levels):
      level_indices = index_levels[-1]
      level_counts = count_levels[-1]
      encoder_block = getattr(self, 'encoder_' + str(level))
      encoded = encoder_block(
          (feature_levels[-1], level_indices, level_counts), training)
      pooled_features, pooled_indices, pooled_counts, mapping = (
          sparse_voxel_net_utils.voxel_pooling(
              voxel_features=encoded,
              voxel_xyz_indices=level_indices,
              num_valid_voxels=level_counts,
              pooling_size=(2, 2, 2),
              segment_func=self.network_pooling_segment_func))
      feature_levels.append(pooled_features)
      index_levels.append(pooled_indices)
      count_levels.append(pooled_counts)
      pooling_maps.append(mapping)

    # Bottleneck: replaces the coarsest features in place.
    feature_levels[-1] = self.middle_layer(
        (feature_levels[-1], index_levels[-1], count_levels[-1]), training)

    # Decoder: upsample, concatenate with the skip features, convolve.
    for level in reversed(range(self.num_levels)):
      upsampled = sparse_voxel_net_utils.voxel_upsampling(
          pooled_voxel_features=feature_levels[level + 1],
          index_mapping=pooling_maps[level])
      merged = tf.concat([feature_levels[level], upsampled], axis=2)
      decoder_block = getattr(self, 'decoder_' + str(level))
      feature_levels[level] = decoder_block(
          (merged, index_levels[level], count_levels[level]), training)

    # Task heads, each two conv blocks deep, at the input resolution.
    base_indices = index_levels[0]
    base_counts = count_levels[0]
    outputs = {}
    for task_name in sorted(self.task_names_to_num_output_channels):
      first_block = getattr(self, f'{task_name}/final_conv1_block')
      second_block = getattr(self, f'{task_name}/final_conv2_block')
      hidden = first_block(
          (feature_levels[0], base_indices, base_counts), training)
      outputs[task_name] = second_block(
          (hidden, base_indices, base_counts), training)

    return outputs