Example #1
    def grad_fn(dy):
      """Compute gradients using a while loop to save memory."""
      support_keys_id = tf.identity(support_keys)
      support_values_id = tf.identity(support_values)
      initial = (0, tf.zeros(tf.shape(query_queries)[1:],
                             dtype=dy.dtype)[tf.newaxis, :][:zero_dim],
                 tf.zeros(tf.shape(query_values)[1:],
                          dtype=dy.dtype)[tf.newaxis, :][:zero_dim],
                 tf.zeros(tf.shape(support_keys_id), dtype=dy.dtype),
                 tf.zeros(tf.shape(support_values_id), dtype=dy.dtype))

      def loop_body(idx, qq_grad, qv_grad, sk_grad, sv_grad):
        """Compute gradients for a single query."""
        qq = query_queries[idx:idx + 1]
        qv = query_values[idx:idx + 1]
        x = self._get_dist(qq, qv, support_keys_id, support_values_id, labels)
        grads = tf.gradients(
            x, [qq, qv, support_keys_id, support_values_id],
            grad_ys=dy[:, idx:idx + 1])
        qq_grad = tf.concat([qq_grad, grads[0]], axis=0)
        qv_grad = tf.concat([qv_grad, grads[1]], axis=0)
        sk_grad += grads[2]
        sv_grad += grads[3]
        return (idx + 1, qq_grad, qv_grad, sk_grad, sv_grad)

      agg_grads = tf.while_loop(
          lambda *arg: arg[0] < tf.shape(query_queries)[0],
          loop_body,
          initial,
          parallel_iterations=1)
      return agg_grads[1:] + (None,)
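This gradient function is evidently meant to be returned from a function decorated with `tf.custom_gradient`, so that the per-query `tf.while_loop` replaces the default backward pass and keeps peak memory low (`zero_dim` is presumably a zero-valued constant from the enclosing scope, used to build empty accumulators). Below is a minimal, self-contained sketch of the same technique on a toy op, assuming TF1-style graph mode; it illustrates the pattern and is not the library's implementation.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


@tf.custom_gradient
def row_wise_square_sum(x):
  """Toy op whose backward pass loops over rows to limit peak memory."""
  y = tf.reduce_sum(tf.square(x), axis=-1)

  def grad_fn(dy):
    n = tf.shape(x)[0]
    # Empty accumulator; gradient rows are concatenated one at a time.
    initial = (0, tf.zeros(tf.stack([0, tf.shape(x)[1]]), dtype=dy.dtype))

    def loop_body(idx, x_grad):
      xi = x[idx:idx + 1]
      yi = tf.reduce_sum(tf.square(xi), axis=-1)
      gi = tf.gradients(yi, xi, grad_ys=dy[idx:idx + 1])[0]
      return idx + 1, tf.concat([x_grad, gi], axis=0)

    _, x_grad = tf.while_loop(
        lambda idx, _: idx < n,
        loop_body,
        initial,
        shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None])),
        parallel_iterations=1)
    return x_grad

  return y, grad_fn


# Usage: gradients flow through the loop-based backward pass above.
x = tf.placeholder(tf.float32, [None, 8])
dx = tf.gradients(row_wise_square_sum(x), x)[0]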
Example #2
    def detailed_forward_pass(self, data):
        """Returns all information from a forward pass of the `OptimizationLearner`.

    Args:
      data: A `meta_dataset.providers.Episode` containing the data for the
        episode.

    Returns:
      A `collections.NamedTuple` that contains the results of the forward pass.
    """
        # Loop initialization.
        init_loop_variables = self.task_parameters
        init_loop_variable_refs = [
            v.experimental_ref() for v in init_loop_variables
        ]

        # Construct ops for data-dependent episodic initialization.
        episodic_init_ops = self.episodic_init_ops(
            labels=data.support_labels,
            embeddings=self.embedding_fn(data.support_images, training=True),
            task_parameters=init_loop_variables,
        )

        def _forward_pass(iteration_idx_, variables_mapping_, images_,
                          onehot_labels_):
            """Helper function to compute the outputs of a forward pass."""

            with self.embedding_fn.reparameterize(variables_mapping_):
                # TODO(eringrant): Implement non-transductive batch normalization
                # (i.e., pass the support set statistics through the query set
                # forward pass).
                embeddings_ = self.embedding_fn(images_, training=True)

            # TODO(eringrant): `head_fn` is an attribute of the subclass.
            with self.head_fn.reparameterize(variables_mapping_):
                predictions_ = self.head_fn(embeddings_)[:, :data.way]

            accuracy_ = tf.reduce_mean(input_tensor=self.compute_accuracy(
                onehot_labels=onehot_labels_, predictions=predictions_))

            inner_objective_ = self.inner_objective(
                onehot_labels=onehot_labels_,
                predictions=predictions_,
                iteration_idx=iteration_idx_)

            outer_objective_ = self.outer_objective(
                onehot_labels=onehot_labels_,
                predictions=predictions_,
            )

            return ForwardPass(
                embeddings=embeddings_,
                predictions=predictions_,
                inner_objective_value=inner_objective_,
                outer_objective_value=outer_objective_,
                accuracy=accuracy_,
            )

        def _objective_fn(loop_variables_, iteration_idx_):
            """Evaluate the support set objective given `loop_variables_`."""

            # Get attribute paths for the loop_variables.
            loop_variables_mapping_ = dict(
                zip(init_loop_variable_refs, loop_variables_))

            adaptation_support_results = _forward_pass(
                iteration_idx_=iteration_idx_,
                variables_mapping_=loop_variables_mapping_,
                images_=data.support_images,
                onehot_labels_=data.onehot_support_labels)

            return adaptation_support_results.inner_objective_value

        def _e_step(loop_variables_):
            """Evaluate expectations given `loop_variables_`."""

            # Get attribute paths for the loop_variables.
            loop_variables_dict_ = dict(
                zip(init_loop_variable_refs, loop_variables_))

            with self.embedding_fn.reparameterize(loop_variables_dict_):
                # TODO(eringrant): `training` is set to True here to normalize with
                # batch statistics. Figure out the appropriate way to pass this
                # around.
                train_embeddings_ = self.embedding_fn(data.train_images,
                                                      training=True)

            class_embeddings_ = learner_base.class_specific_data(
                data.onehot_train_labels, train_embeddings_, self.logit_dim)

            def _compute_responsibilities(examples_, class_idx):
                train_predictions_ = tf.squeeze(
                    self.head_fn(embeddings=examples_,
                                 components=True,
                                 class_idx=[class_idx]),
                    axis=1)
                return tf.nn.softmax(train_predictions_, axis=-1)

            with self.head_fn.reparameterize(loop_variables_dict_):
                class_responsibilities_ = [
                    _compute_responsibilities(embeddings_, class_idx=i)
                    for i, embeddings_ in enumerate(class_embeddings_)
                ]

            return class_embeddings_, class_responsibilities_

        def _m_step(preupdate_vars, all_embeddings_, all_responsibilities_):
            """Compute parameter estimates given `loop_variables_`."""

            means, log_scales, logits = zip(
                *map(reparameterizable_distributions.fit_gaussian_mixture,
                     all_embeddings_, all_responsibilities_,
                     itertools.repeat(self.head_fn.damping)))

            def flatten(x):
                return list(itertools.chain.from_iterable(x))

            means = flatten(means)
            log_scales = flatten(log_scales)
            logits = flatten(logits)

            if not self.head_fn.estimate_loc:
                means = [None for _ in means]

            if not self.head_fn.estimate_scale:
                log_scales = [None for _ in log_scales]

            if not self.head_fn.estimate_logits:
                logits = [None for _ in logits]

            updated_vars = means + log_scales + logits

            # Replace constant variables.
            # TODO(eringrant): This interface differs from just excluding these
            # variables from `task_variables`.
            no_none_updated_vars = []
            for preupdate_var, updated_var in zip(preupdate_vars,
                                                  updated_vars):
                if updated_var is None:
                    no_none_updated_vars.append(preupdate_var)
                else:
                    no_none_updated_vars.append(updated_var)

            # TODO(eringrant): This assumes an ordering of mean, log_scales,
            # mixing_logits.
            return no_none_updated_vars

        # Loop body.
        with tf.control_dependencies(episodic_init_ops):

            # Inner loop of expectation maximization.
            num_em_steps = getattr(self, 'num_em_steps', 0)
            if num_em_steps > 0:
                final_loop_variables = em_loop(num_updates=num_em_steps,
                                               e_step=_e_step,
                                               m_step=_m_step,
                                               variables=init_loop_variables)

            # Inner loop of gradient-based optimization.
            num_optimizer_steps = (self.num_update_steps +
                                   (self.additional_evaluation_update_steps
                                    if not self.is_training else 0))
            if num_optimizer_steps > 0:
                # pylint: disable=no-value-for-parameter
                final_loop_variables = optimizer_loop(
                    num_updates=num_optimizer_steps,
                    objective_fn=_objective_fn,
                    update_fn=self.update_fn,
                    variables=init_loop_variables,
                    first_order=self.first_order,
                    clip_grad_norm=self.clip_grad_norm,
                )
                # pylint: enable=no-value-for-parameter

            # If no inner loop adaptation is performed, ensure the episodic
            # initialization is still part of the graph via a control dependency.
            if num_optimizer_steps + num_em_steps == 0:
                final_loop_variables = [
                    tf.identity(v) for v in init_loop_variables
                ]

        # Get variable references to use when remapping the loop_variables.
        init_loop_variables_mapping = dict(
            zip(init_loop_variable_refs, init_loop_variables))
        final_loop_variables_mapping = dict(
            zip(init_loop_variable_refs, final_loop_variables))

        # Collect statistics about the inner optimization.
        with tf.compat.v1.name_scope('pre-adaptation'):
            with tf.compat.v1.name_scope('support'):
                pre_adaptation_support_results = _forward_pass(
                    iteration_idx_=0,
                    variables_mapping_=init_loop_variables_mapping,
                    images_=data.support_images,
                    onehot_labels_=data.onehot_support_labels)

            with tf.compat.v1.name_scope('query'):
                pre_adaptation_query_results = _forward_pass(
                    iteration_idx_=0,
                    variables_mapping_=init_loop_variables_mapping,
                    images_=data.query_images,
                    onehot_labels_=data.onehot_query_labels)

        with tf.compat.v1.name_scope('post-adaptation'):
            with tf.compat.v1.name_scope('support'):
                post_adaptation_support_results = _forward_pass(
                    iteration_idx_=num_optimizer_steps,
                    variables_mapping_=final_loop_variables_mapping,
                    images_=data.support_images,
                    onehot_labels_=data.onehot_support_labels,
                )

            with tf.compat.v1.name_scope('query'):
                post_adaptation_query_results = _forward_pass(
                    iteration_idx_=num_optimizer_steps,
                    variables_mapping_=final_loop_variables_mapping,
                    images_=data.query_images,
                    onehot_labels_=data.onehot_query_labels,
                )

        def _support_module_objective_fn(module_variables_,
                                         module_variable_refs_):
            """Evaluate the query set objective given `module_variables_`."""
            # Use the values of the parameters at convergence as the default value.
            variables_mapping_ = final_loop_variables_mapping.copy()

            # Loop over and replace the module-specific variables.
            for module_variable_ref, module_variable in zip(
                    module_variable_refs_, module_variables_):
                variables_mapping_[module_variable_ref] = module_variable

            adaptation_query_results = _forward_pass(
                iteration_idx_=num_optimizer_steps,
                variables_mapping_=variables_mapping_,
                images_=data.support_images,
                onehot_labels_=data.onehot_support_labels,
            )

            return adaptation_query_results.inner_objective_value

        def _query_module_objective_fn(module_variables_,
                                       module_variable_refs_):
            """Evaluate the query set objective given `module_variables_`."""
            # Use the values of the parameters at convergence as the default value.
            variables_mapping_ = final_loop_variables_mapping.copy()

            # Loop over and replace the module-specific variables.
            for module_variable_ref, module_variable in zip(
                    module_variable_refs_, module_variables_):
                variables_mapping_[module_variable_ref] = module_variable

            adaptation_query_results = _forward_pass(
                iteration_idx_=num_optimizer_steps,
                variables_mapping_=variables_mapping_,
                images_=data.query_images,
                onehot_labels_=data.onehot_query_labels)

            return adaptation_query_results.inner_objective_value

        return Adaptation(
            pre_adaptation_support_results=pre_adaptation_support_results,
            post_adaptation_support_results=post_adaptation_support_results,
            pre_adaptation_query_results=pre_adaptation_query_results,
            post_adaptation_query_results=post_adaptation_query_results,
            objective_fn=_objective_fn,
            support_module_objective_fn=_support_module_objective_fn,
            query_module_objective_fn=_query_module_objective_fn,
            forward_pass_fn=_forward_pass,
            init_loop_variables_mapping=init_loop_variables_mapping,
            final_loop_variables_mapping=final_loop_variables_mapping,
        )
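The `ForwardPass` and `Adaptation` containers returned above are not shown in this snippet. A hypothetical reconstruction, based solely on the fields used in the example (the library's actual definitions may differ), could look like:

import collections

# Fields inferred from the usage above; names and ordering are assumptions.
ForwardPass = collections.namedtuple('ForwardPass', [
    'embeddings', 'predictions', 'inner_objective_value',
    'outer_objective_value', 'accuracy'
])

Adaptation = collections.namedtuple('Adaptation', [
    'pre_adaptation_support_results', 'post_adaptation_support_results',
    'pre_adaptation_query_results', 'post_adaptation_query_results',
    'objective_fn', 'support_module_objective_fn', 'query_module_objective_fn',
    'forward_pass_fn', 'init_loop_variables_mapping',
    'final_loop_variables_mapping'
])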
Example #3
    def forward_pass(self, data):
        """Computes the query logits for the given episode `data`."""

        if self.film_init == 'scratch':
            self.film_selector = None
        elif self.film_init == 'imagenet':
            # Note: this makes the assumption that the first set of learned FiLM
            # parameters corresponds to the ImageNet dataset. Otherwise, the
            # following line should be changed appropriately.
            self.film_selector = 0
        elif self.film_init in ['blender', 'blender_hard']:
            dataset_logits = functional_backbones.dataset_classifier(
                data.support_images)
            if self.film_init == 'blender_hard':
                # Select only the argmax entry.
                self.film_selector = tf.one_hot(
                    tf.math.argmax(dataset_logits, axis=-1),
                    depth=tf.shape(dataset_logits)[1])
            else:
                # Take a convex combination.
                self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

        if self.num_steps:
            # Initial forward pass, required for the `unused_op` below and to place
            # variables into tf.trainable_variables() so the block below can pick
            # them up.
            loss = self._compute_losses(data, compute_on_query=False)['loss']

            # Pick out the variables to optimize.
            self.opt_vars = []
            for var in tf.trainable_variables():
                if '_for_film_learner' in var.name:
                    self.opt_vars.append(var)
            tf.logging.info('FiLMLearner will optimize vars: {}'.format(
                self.opt_vars))

        for i in range(self.num_steps):
            if i == 0:
                # Re-initialize the variables to optimize for the new episode, to ensure
                # the FiLM parameters aren't re-used across tasks of a given dataset.
                vars_reset = tf.variables_initializer(var_list=self.opt_vars)
                # Adam-related variables are created when minimize() is called.
                # We create an unused op here to put all Adam variables under
                # the 'adam_opt' name scope, and create a reset op to reinitialize
                # these variables before the first finetune step.
                with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                    unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
                adam_reset = tf.variables_initializer(self.opt.variables())

                with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                             self.opt_vars):
                    print_op = tf.no_op()
                    if self.debug_log:
                        print_op = tf.print(
                            ['step: %d' % i, self.opt_vars[0][0], 'loss:', loss],
                            summarize=-1)

                    with tf.control_dependencies([print_op]):
                        # Get the train op.
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

            else:
                with tf.control_dependencies([train_op, loss, acc] +
                                             self.opt_vars +
                                             [query_loss, query_acc] *
                                             int(self.debug_log)):

                    print_op = tf.no_op()
                    if self.debug_log:
                        print_list = [
                            '################',
                            'step: %d' % i,
                            self.opt_vars[0][0],
                            'support loss:',
                            loss,
                            'query loss:',
                            query_loss,
                            'support acc:',
                            acc,
                            'query acc:',
                            query_acc,
                        ]
                        print_op = tf.print(print_list)

                    with tf.control_dependencies([print_op]):
                        # Get the train op (the loss is returned just for printing).
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

        # Training is now over, compute the final query logits.
        dependency_list = [] if not self.num_steps else [train_op] + self.opt_vars
        with tf.control_dependencies(dependency_list):
            results = self._compute_losses(data, compute_on_query=True)
            (loss, query_loss, query_logits, acc,
             query_acc) = (results['loss'], results['query_loss'],
                           results['query_logits'], results['acc'],
                           results['query_acc'])

            print_op = tf.no_op()
            if self.debug_log:
                print_op = tf.print([
                    'Done training',
                    'support loss:',
                    loss,
                    'query loss:',
                    query_loss,
                    'support acc:',
                    acc,
                    'query acc:',
                    query_acc,
                ])
            with tf.control_dependencies([print_op]):
                query_logits = tf.identity(query_logits)

        return query_logits
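For intuition, the 'blender_hard' and 'blender' settings differ only in whether the dataset classifier's output selects a single set of FiLM parameters or takes a convex combination of them. A small, self-contained illustration of that difference follows; the variable names are illustrative, not the library's API.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

num_datasets, num_channels = 4, 8
# One learned FiLM scale vector per training dataset (illustrative only).
film_gammas = tf.get_variable('film_gammas', [num_datasets, num_channels])
dataset_logits = tf.placeholder(tf.float32, [1, num_datasets])

# 'blender_hard': pick the argmax dataset's parameters.
hard_selector = tf.one_hot(tf.math.argmax(dataset_logits, axis=-1),
                           depth=tf.shape(dataset_logits)[1])
hard_gamma = tf.matmul(hard_selector, film_gammas)

# 'blender': convex combination of all datasets' parameters.
soft_selector = tf.nn.softmax(dataset_logits, axis=-1)
soft_gamma = tf.matmul(soft_selector, film_gammas)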
Example #4
    def get_updated_global_step(self):
        with tf.control_dependencies([self.train_op]):
            global_step = tf.identity(tf.train.get_global_step())
        return global_step
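The control dependency guarantees that reading the global step happens only after `self.train_op` has run, so the fetched value reflects the step just completed. A minimal self-contained sketch of the same pattern, with a toy loss and optimizer standing in for the learner's (assuming TF1 graph mode):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

w = tf.get_variable('w', [], initializer=tf.ones_initializer())
loss = tf.square(w)
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

# Reading the step under a control dependency forces the update to run first.
with tf.control_dependencies([train_op]):
  updated_global_step = tf.identity(global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(updated_global_step))  # Prints 1: the step after one update.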
Example #5
def get_train_op(loss,
                 learning_rate=0.001,
                 lr_decay_steps=10000,
                 lr_decay_rate=0.98,
                 gradient_clip_norm=3.0,
                 use_tpu=True,
                 variables=None):
    """Get training operation with gradient clipping and learning rate decay.

  Distilled from tf.contrib.layers.optimize_loss().
  Args:
    loss: Scalar tensor of the loss function.
    learning_rate: Scalar initial learning rate.
    lr_decay_steps: Exponential decay timescale.
    lr_decay_rate: Exponential decay magnitude.
    gradient_clip_norm: Global norm by which to scale gradients.
    use_tpu: Use tpu for training.
    variables: List of variables to optimize. tf.trainable_variables() if None.

  Returns:
    train_op: Operation that runs one iteration of training.
  """
    global_step = tf.train.get_or_create_global_step()

    with tf.variable_scope('training', values=[loss, global_step]):
        # Make sure update ops run before computing loss.
        update_ops = list(set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
        with tf.control_dependencies(update_ops):
            loss = tf.identity(loss)

        # Learning rate variable, with decay.
        learning_rate_decay_fn = functools.partial(tf.train.exponential_decay,
                                                   decay_steps=lr_decay_steps,
                                                   decay_rate=lr_decay_rate,
                                                   staircase=True)
        lr = tf.get_variable(
            'learning_rate', [],
            trainable=False,
            initializer=tf.constant_initializer(learning_rate))
        lr = learning_rate_decay_fn(lr, global_step)

        # Optimizer.
        opt = tf.train.AdamOptimizer(lr)
        if use_tpu:
            opt = tf.tpu.CrossShardOptimizer(opt)

        # All trainable variables, if specific variables are not specified.
        if variables is None:
            variables = tf.trainable_variables()

        # Compute gradients.
        gradients = opt.compute_gradients(loss,
                                          variables,
                                          colocate_gradients_with_ops=False)

        # Optionally clip gradients by global norm.
        if isinstance(gradient_clip_norm, float):
            gradients = _clip_gradients_by_norm(gradients, gradient_clip_norm)

        # Create gradient updates.
        grad_updates = opt.apply_gradients(gradients,
                                           global_step=global_step,
                                           name='train')

        # Ensure the train_op computes grad_updates.
        with tf.control_dependencies([grad_updates]):
            train_op = tf.identity(loss)

        return train_op
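The helper `_clip_gradients_by_norm` used above is not shown, and `functools` is assumed to be imported at module level. Below is a plausible sketch of the helper, consistent with the docstring (global-norm clipping, as in `tf.contrib.layers.optimize_loss`), followed by a hypothetical call to `get_train_op` with `use_tpu=False`; both are illustrations, not the library's code.

def _clip_gradients_by_norm(grads_and_vars, gradient_clip_norm):
  """Sketch of the assumed helper: clip all gradients by their global norm."""
  grads, variables = zip(*grads_and_vars)
  clipped_grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm)
  return list(zip(clipped_grads, variables))


# Hypothetical usage, assuming TF1 graph mode and no TPU.
w = tf.get_variable('w', [10])
loss = tf.reduce_mean(tf.square(w - 1.0))
train_op = get_train_op(loss, learning_rate=0.01, use_tpu=False)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)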