Example #1
    def sample(self, reinitialize=False):
        """Samples weights for the network, displaying training and test errors."""

        # Loss histories, recorded once per kept (thinned) sample.
        training_loss_h, test_loss_h = [], []

        # Only one out of every thinning_interval iterations yields a sample.
        num_iters = self.num_samples * self.thinning_interval

        # Storage for the flattened last-layer weights of each kept sample.
        num_ll_weights = int((self.dim_input + 1) * self.num_classes)
        sampled_weights = np.zeros((self.num_samples, num_ll_weights))

        if reinitialize:
            # Draw a fresh random initialization for the last-layer variables.
            self.sess.run(tf.variables_initializer(self.ll_vars))

        start_time = time.time()
        separator = '-----------------------------------------------------'
        print(separator)
        print('Starting sampling of the Bayesian Neural Network by ' +
              six.ensure_str(self.sampler))

        for step in range(num_iters):
            # One sampler/optimizer step on a fresh minibatch.
            batch_x, batch_y = self.next_batch()
            feed_dict = {self.x: batch_x, self.y: batch_y}
            self.sess.run([self.train_op], feed_dict=feed_dict)

            if (step + 1) % self.thinning_interval != 0:
                continue

            # This iteration produces a kept sample: record summaries, losses
            # and the flattened (column-major) last-layer weights.
            summary, loss_v, ll_vars_v = self.sess.run(
                [self.all_summaries, self.loss, self.ll_vars_concat],
                feed_dict=feed_dict)
            training_loss_h.append(loss_v)
            self.writer.add_summary(summary, step + 1)
            self.writer.flush()
            sampled_weights[step // self.thinning_interval, :] = (
                ll_vars_v.flatten('F'))

            # Evaluate the loss on the held-out test set.
            feed_dict_test = {self.x: self.x_test, self.y: self.y_test}
            test_loss_v = self.sess.run([self.loss], feed_dict=feed_dict_test)
            test_loss_h.append(test_loss_v)
            print('{} steps. Loss = {}. \t Test Loss = {}.'.format(
                str(step + 1), str(loss_v), str(test_loss_v)))

            # Checkpoint the current weights for this sample index.
            self.saver.save(
                self.sess,
                os.path.join(self.working_dir, 'weights/saved-last-weights'),
                global_step=step + 1,
                write_meta_graph=False)

        print(separator)
        print('Training complete after {} seconds.'.format(time.time() -
                                                           start_time))
        print(separator)

        self.writer.close()

        return training_loss_h, test_loss_h, sampled_weights
Example #2
    def sample(self, reinitialize=False):
        """Samples weights after training, for bootstrap we only need one sample."""

        # Loss histories, kept for sanity checking.
        training_loss_h, test_loss_h = [], []

        # The bootstrap model is first trained to convergence on its
        # bootstrapped data; only the final weights constitute the sample.
        num_training_iters = self.num_training_iters

        # A single flattened sample of the last-layer weights.
        num_ll_weights = int((self.dim_input + 1) * self.num_classes)
        sampled_weights = np.zeros((1, num_ll_weights))

        if reinitialize:
            # Fresh random initialization of the last-layer variables.
            self.sess.run(tf.variables_initializer(self.ll_vars))

        start_time = time.time()
        separator = '-----------------------------------------------------'
        print(separator)
        print('Starting sampling of the Bootstrapped Neural Network.')

        for step in range(num_training_iters):
            # One optimization step on a bootstrapped minibatch.
            batch_x, batch_y = self.next_batch()
            feed_dict = {self.x: batch_x, self.y: batch_y}
            _, summary, loss_v = self.sess.run(
                [self.train_op, self.all_summaries, self.loss],
                feed_dict=feed_dict)
            self.writer.add_summary(summary, step)
            self.writer.flush()
            training_loss_h.append(loss_v)

            # Report the held-out test loss every 100 iterations.
            if step % 100 == 0:
                feed_dict = {self.x: self.x_test, self.y: self.y_test}
                test_loss_v = self.sess.run([self.loss], feed_dict=feed_dict)
                test_loss_h.append(test_loss_v)
                print('{} steps. Loss = {}. \t Test Loss = {}.'.format(
                    str(step), str(loss_v), str(test_loss_v)))

        # Store the converged last-layer weights as the single sample.
        # NOTE(review): `feed_dict` here is whatever dict was bound last in the
        # loop (the test feed if the final step hit the reporting branch);
        # presumably ll_vars_concat does not depend on the feed — confirm.
        ll_vars_v = self.sess.run(self.ll_vars_concat, feed_dict=feed_dict)
        sampled_weights[0, :] = ll_vars_v.flatten('F')

        # Checkpoint the final per-worker weights.
        self.saver.save(
            self.sess,
            os.path.join(
                self.working_dir,
                'weights/saved-boot{}-last-weights'.format(self.worker_id)),
            global_step=0,
            write_meta_graph=False)

        print(separator)
        print('Training complete after {} seconds.'.format(time.time() -
                                                           start_time))
        print(separator)

        self.writer.close()

        return training_loss_h, test_loss_h, sampled_weights
Example #3
    def forward_pass(self, data):
        """Computes the query logits for the given episode `data`.

        Builds an unrolled graph of `self.num_steps` finetuning steps over the
        FiLM variables, chained through `tf.control_dependencies` so the steps
        execute sequentially, then computes the final query-set logits.

        Args:
          data: An episode providing `support_images` and whatever
            `self._compute_losses` / `self._get_train_op` consume — presumably
            a meta-dataset Episode; confirm against the caller.

        Returns:
          The query set logits tensor, produced after the final finetune step.
        """

        # Configure how the per-dataset FiLM parameters are selected/combined.
        if self.film_init == 'scratch':
            # No selector: FiLM parameters are learned from scratch.
            self.film_selector = None
        elif self.film_init == 'imagenet':
            # Note: this makes the assumption that the first set of learned FiLM
            # parameters corresponds to the ImageNet dataset. Otherwise, the
            # following line should be changed appropriately.
            self.film_selector = 0
        elif self.film_init in ['blender', 'blender_hard']:
            # Classify which dataset the support images resemble and use the
            # logits to pick (hard) or blend (soft) the FiLM parameter sets.
            dataset_logits = functional_backbones.dataset_classifier(
                data.support_images)
            if self.film_init == 'blender_hard':
                # Select only the argmax entry.
                self.film_selector = tf.one_hot(
                    tf.math.argmax(dataset_logits, axis=-1),
                    depth=tf.shape(dataset_logits)[1])
            else:
                # Take a convex combination.
                self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

        if self.num_steps:
            # Initial forward pass, required for the `unused_op` below and for placing
            # variables in tf.trainable_variables() for the below block to pick up.
            loss = self._compute_losses(data, compute_on_query=False)['loss']

            # Pick out the variables to optimize.
            self.opt_vars = []
            for var in tf.trainable_variables():
                if '_for_film_learner' in var.name:
                    self.opt_vars.append(var)
            tf.logging.info('FiLMLearner will optimize vars: {}'.format(
                self.opt_vars))

        # Unroll the finetune steps; control dependencies enforce that each
        # step's ops run only after the previous step's train op.
        for i in range(self.num_steps):
            if i == 0:
                # Re-initialize the variables to optimize for the new episode, to ensure
                # the FiLM parameters aren't re-used across tasks of a given dataset.
                vars_reset = tf.variables_initializer(var_list=self.opt_vars)
                # Adam related variables are created when minimize() is called.
                # We create an unused op here to put all adam varariables under
                # the 'adam_opt' namescope and create a reset op to reinitialize
                # these variables before the first finetune step.
                with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                    unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
                adam_reset = tf.variables_initializer(self.opt.variables())

                # The first step runs only after both resets (and the initial
                # loss) have executed.
                with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                             self.opt_vars):
                    print_op = tf.no_op()
                    if self.debug_log:
                        print_op = tf.print([
                            'step: %d' % i, self.opt_vars[0][0], 'loss:', loss
                        ],
                                            summarize=-1)

                    with tf.control_dependencies([print_op]):
                        # Get the train op.
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

            else:
                # Subsequent steps depend on the previous step's train op (and,
                # when debug logging, on the query metrics) so they execute in
                # order when the graph is run.
                with tf.control_dependencies([train_op, loss, acc] +
                                             self.opt_vars +
                                             [query_loss, query_acc] *
                                             int(self.debug_log)):

                    print_op = tf.no_op()
                    if self.debug_log:
                        print_list = [
                            '################',
                            'step: %d' % i,
                            self.opt_vars[0][0],
                            'support loss:',
                            loss,
                            'query loss:',
                            query_loss,
                            'support acc:',
                            acc,
                            'query acc:',
                            query_acc,
                        ]
                        print_op = tf.print(print_list)

                    with tf.control_dependencies([print_op]):
                        # Get the train op (the loss is returned just for printing).
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

        # Training is now over, compute the final query logits.
        dependency_list = [] if not self.num_steps else [train_op
                                                         ] + self.opt_vars
        with tf.control_dependencies(dependency_list):
            results = self._compute_losses(data, compute_on_query=True)
            (loss, query_loss, query_logits, acc,
             query_acc) = (results['loss'], results['query_loss'],
                           results['query_logits'], results['acc'],
                           results['query_acc'])

            print_op = tf.no_op()
            if self.debug_log:
                print_op = tf.print([
                    'Done training',
                    'support loss:',
                    loss,
                    'query loss:',
                    query_loss,
                    'support acc:',
                    acc,
                    'query acc:',
                    query_acc,
                ])
            # tf.identity under the print dependency ensures the final debug
            # print fires whenever the returned logits are evaluated.
            with tf.control_dependencies([print_op]):
                query_logits = tf.identity(query_logits)

        return query_logits
Example #4
  def compute_logits(self, data):
    """Computes the class logits for the episode.

    Builds `self.num_finetune_steps` finetuning steps over the fc layer (and
    optionally the whole embedding network), chained via control dependencies,
    then evaluates the query set with the finetuned weights.

    Args:
      data: A `meta_dataset.providers.Episode`.

    Returns:
      The query set logits as a [num_query_images, way] matrix.

    Raises:
      ValueError: Distance must be one of l2 or cosine.
    """
    # ------------------------ Finetuning -------------------------------
    # Possibly make copies of embedding variables, if they will get modified.
    # This is for making temporary-only updates to the embedding network
    # which will not persist after the end of the episode.
    make_copies = self.finetune_all_layers

    # TODO(eringrant): Reduce the number of times the embedding function graph
    # is built with the same input.
    support_embeddings_params_moments = self.embedding_fn(
        data.support_images, self.is_training)
    support_embeddings = support_embeddings_params_moments['embeddings']
    support_embeddings_var_dict = support_embeddings_params_moments['params']

    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         support_embeddings_var_dict, make_copies)
    embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

    # Compute the initial training loss (only for printing purposes). This
    # line is also needed for adding the fc variables to the graph so that the
    # tf.all_variables() line below detects them.
    logits = self._fc_layer(support_embeddings)[:, 0:data.way]
    finetune_loss = self.compute_loss(
        onehot_labels=data.onehot_support_labels,
        predictions=logits,
    )

    # Decide which variables to finetune.
    fc_vars, vars_to_finetune = [], []
    for var in tf.trainable_variables():
      if 'fc_finetune' in var.name:
        fc_vars.append(var)
        vars_to_finetune.append(var)
    if self.finetune_all_layers:
      vars_to_finetune.extend(embedding_vars)
    logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

    # Unroll the finetune steps; control dependencies force sequential
    # execution of the per-step ops at session run time.
    for i in range(self.num_finetune_steps):
      if i == 0:
        # Randomly initialize the fc layer.
        fc_reset = tf.variables_initializer(var_list=fc_vars)
        # Adam related variables are created when minimize() is called.
        # We create an unused op here to put all adam varariables under
        # the 'adam_opt' namescope and create a reset op to reinitialize
        # these variables before the first finetune step.
        adam_reset = tf.no_op()
        if self.finetune_with_adam:
          with tf.variable_scope('adam_opt'):
            unused_op = self.finetune_opt.minimize(
                finetune_loss, var_list=vars_to_finetune)
          adam_reset = tf.variables_initializer(self.finetune_opt.variables())
        # The first step runs only after the resets and the (possible)
        # embedding-variable copies have executed.
        with tf.control_dependencies(
            [fc_reset, adam_reset, finetune_loss, embedding_vars_copy_op] +
            vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
                finetune_loss
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

      else:
        # Subsequent steps depend on the previous step's finetune op so they
        # run in sequence.
        with tf.control_dependencies([finetune_op, finetune_loss] +
                                     vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i,
                vars_to_finetune[0][0, 0],
                'loss:',
                finetune_loss,
                'accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_support_labels, predictions=logits),
                'query accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_query_labels, predictions=query_logits),
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

    # Finetuning is now over, compute the query performance using the updated
    # fc layer, and possibly the updated embedding network.
    with tf.control_dependencies([finetune_op] + vars_to_finetune):
      query_embeddings = self.embedding_fn(
          data.query_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']
      query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

      if self.debug_log:
        # The train logits are computed only for printing.
        support_embeddings = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, embedding_vars)),
            reuse=True)['embeddings']
        logits = self._fc_layer(support_embeddings)[:, 0:data.way]

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print([
            'accuracy:',
            self.compute_accuracy(
                labels=data.onehot_support_labels, predictions=logits),
            'query accuracy:',
            self.compute_accuracy(
                labels=data.onehot_query_labels, predictions=query_logits),
        ])
      # Recompute the returned logits under the print dependency so the final
      # debug print fires whenever they are evaluated.
      with tf.control_dependencies([print_op]):
        query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

    return query_logits
Example #5
    def initialize_session(self):
        """Initializes a tf Session."""
        if ENABLE_TF_OPTIMIZATIONS:
            self.sess = tf.Session()
        else:
            # Turn off every grappler rewriting pass so the executed graph
            # matches the constructed one exactly.
            off = rewriter_config_pb2.RewriterConfig.OFF
            rewriter_config = rewriter_config_pb2.RewriterConfig(
                disable_model_pruning=True,
                constant_folding=off,
                arithmetic_optimization=off,
                remapping=off,
                shape_optimization=off,
                dependency_optimization=off,
                function_optimization=off,
                layout_optimizer=off,
                loop_optimization=off,
                memory_optimization=(
                    rewriter_config_pb2.RewriterConfig.NO_MEM_OPT))
            graph_options = tf.GraphOptions(rewrite_options=rewriter_config)
            session_config = tf.ConfigProto(graph_options=graph_options)
            self.sess = tf.Session(config=session_config)

        # Start from freshly initialized variables; a checkpoint restore below
        # may overwrite them.
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())

        if self.learner_config.checkpoint_for_eval:
            # A specific checkpoint was requested (evaluation mode).
            self.saver.restore(self.sess,
                               self.learner_config.checkpoint_for_eval)
            tf.logging.info('Restored checkpoint: %s' %
                            self.learner_config.checkpoint_for_eval)
        else:
            # Fault tolerance: resume from the newest checkpoint if one exists.
            latest_checkpoint = None
            if self.checkpoint_dir is not None:
                latest_checkpoint = tf.train.latest_checkpoint(
                    self.checkpoint_dir)
            if latest_checkpoint:
                self.saver.restore(self.sess, latest_checkpoint)
                tf.logging.info('Restored checkpoint: %s' % latest_checkpoint)
            else:
                tf.logging.info('No previous checkpoint.')
                # Redundant with the initializer runs above, but kept to
                # mirror the original control flow exactly.
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

        # For episodic models, potentially use pretrained weights at the start
        # of training (global step still 0). This overwrites the embedding
        # weights, but takes care not to restore the Adam parameters.
        if self.learner_config.pretrained_checkpoint and not self.sess.run(
                tf.train.get_global_step()):
            self.saver.restore(self.sess,
                               self.learner_config.pretrained_checkpoint)
            tf.logging.info('Restored checkpoint: %s' %
                            self.learner_config.pretrained_checkpoint)
            # We only want the embedding weights of the checkpoint we just
            # restored, so everything that is not an embedding weight is
            # re-initialized. Since this episodic finetuning is a different
            # optimization problem than the original baseline training, Adam's
            # variables are also learned from scratch rather than reloaded.
            vars_to_reinit, embedding_var_names, vars_to_reinit_names = [], [], []
            for var in tf.global_variables():
                is_embedding_var = (
                    any(kw in var.name for kw in EMBEDDING_KEYWORDS) and
                    'adam' not in var.name.lower())
                if is_embedding_var:
                    embedding_var_names.append(var.name)
                else:
                    vars_to_reinit.append(var)
                    vars_to_reinit_names.append(var.name)
            tf.logging.info('Initializing all variables except for %s.' %
                            embedding_var_names)
            self.sess.run(tf.variables_initializer(vars_to_reinit))
            tf.logging.info('Re-initialized vars %s.' % vars_to_reinit_names)