def _build_train_op(self):
        """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """

        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        replay_chosen_q = tf.reduce_sum(self._replay_qs *
                                        replay_action_one_hot,
                                        axis=1,
                                        name='replay_chosen_q')

        target = tf.stop_gradient(self._build_target_q_op())
        loss = tf.losses.huber_loss(target,
                                    replay_chosen_q,
                                    reduction=tf.losses.Reduction.NONE)

        update_priorities_op = self._replay.tf_set_priority(
            self._replay.indices, tf.sqrt(loss + 1e-10))

        target_priorities = self._replay.tf_get_priority(self._replay.indices)
        target_priorities = tf.math.add(target_priorities, 1e-10)
        target_priorities = 1.0 / tf.sqrt(target_priorities)
        target_priorities /= tf.reduce_max(target_priorities)

        weighted_loss = target_priorities * loss

        with tf.control_dependencies([update_priorities_op]):
            return self.optimizer.minimize(
                tf.reduce_mean(weighted_loss)), weighted_loss
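The priority update and importance weighting above are element-wise: the new priorities are the square roots of the per-example losses, and the loss weights are max-normalized inverse square roots of the sampled priorities. A minimal NumPy sketch of that arithmetic (illustrative only; the replay-buffer tensors and methods are taken from the snippet above):

import numpy as np

per_example_loss = np.array([0.04, 0.25, 1.0])
new_priorities = np.sqrt(per_example_loss + 1e-10)        # fed to tf_set_priority
sampled_priorities = np.array([0.1, 0.5, 2.0])            # as from tf_get_priority
weights = 1.0 / np.sqrt(sampled_priorities + 1e-10)
weights /= weights.max()                                  # least-sampled example gets weight 1.0
weighted_loss = weights * per_example_loss
print(weighted_loss.mean())                               # the scalar passed to the optimizer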
Example #2
  def _build_train_op(self):
    """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution,
        logits=chosen_action_logits)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        epsilon=self.optimizer_epsilon)

    update_priorities_op = self._replay.tf_set_priority(
        self._replay.indices, tf.sqrt(loss + 1e-10))

    target_priorities = self._replay.tf_get_priority(self._replay.indices)
    target_priorities = tf.math.add(target_priorities, 1e-10)
    target_priorities = 1.0 / tf.sqrt(target_priorities)
    target_priorities /= tf.reduce_max(target_priorities)

    weighted_loss = target_priorities * loss

    with tf.control_dependencies([update_priorities_op]):
      return optimizer.minimize(tf.reduce_mean(weighted_loss)), weighted_loss
Example #3
def compute_train_class_proportions(episode, shots, dataset_spec):
    """Computes the proportion of each class' examples in the support set.

  Args:
    episode: An EpisodeDataset.
    shots: A 1D Tensor whose length is the `way' of the episode that stores the
      shots for this episode.
    dataset_spec: A DatasetSpecification.

  Returns:
    class_props: A 1D Tensor whose length is the `way' of the episode, storing
      for each class the proportion of its examples that are in the support set.
  """
    # Get the total number of examples of each class in the dataset.
    num_dataset_classes = len(dataset_spec.images_per_class)
    num_images_per_class = [
        dataset_spec.get_total_images_per_class(class_id)
        for class_id in range(num_dataset_classes)
    ]

    # Get the (absolute) class ID's that appear in the episode.
    class_ids, _ = tf.unique(episode.train_class_ids)  # [?, ]

    # Make sure that class_ids are valid indices of num_images_per_class. This is
    # important since tf.gather will fail silently and return zeros otherwise.
    num_classes = tf.shape(num_images_per_class)[0]
    check_valid_inds_op = tf.assert_less(class_ids, num_classes)
    with tf.control_dependencies([check_valid_inds_op]):
        # Get the total number of examples of each class that is in the episode.
        num_images_per_class = tf.gather(num_images_per_class,
                                         class_ids)  # [?, ]

    # Get the proportions of examples of each class that appear in the train set.
    class_props = tf.truediv(shots, num_images_per_class)
    return class_props
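The proportion computation above reduces to an element-wise division of the episode's shots by the dataset-wide class counts gathered at the episode's class IDs. A toy NumPy sketch of that step (made-up numbers, not a real DatasetSpecification):

import numpy as np

num_images_per_class = np.array([100, 50, 80, 200])    # total examples per absolute class ID
class_ids = np.array([2, 0])                            # absolute IDs appearing in the episode
shots = np.array([5, 10])                               # support-set examples per episode class
class_props = shots / num_images_per_class[class_ids]   # -> [0.0625, 0.1]
print(class_props)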
Example #4
def assign_variables(variable_mapping):
    """Assign variables according to the provided `variable_mapping`.

  Args:
    variable_mapping: An iterable of variable pairs, each corresponding to a
      variable whose value is to be overwritten (destination) and a reference
      variable (source).

  Returns:
    None. In TensorFlow Eager mode the assignments are performed immediately; in
    graph mode they are folded into each destination variable's initializer op.
  """
    for variable, reference_variable in variable_mapping:
        if tf.executing_eagerly():
            # Just perform the assignment.
            variable.assign(reference_variable)
        else:
            # Piggyback on the variable's initializer attribute, which is included in
            # `tf.global_variables_initializer`.
            initializer_ops = [variable._initializer_op]  # pylint: disable=protected-access
            if isinstance(reference_variable, tf.Variable):
                initializer_ops += [reference_variable._initializer_op]  # pylint: disable=protected-access
            with tf.control_dependencies(initializer_ops):
                assign_op = variable.assign(reference_variable)
            variable._initializer_op = assign_op  # pylint: disable=protected-access
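A minimal graph-mode usage sketch, assuming TensorFlow 1.x behavior (tf.compat.v1) and the assign_variables function above in scope; because the assignment is folded into the destination variable's initializer, running the global initializer performs the copy:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
src = tf.Variable(3.0, name='src')
dst = tf.Variable(0.0, name='dst')
assign_variables([(dst, src)])  # dst's initializer now assigns src's value

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(dst))  # 3.0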
Example #5
  def _set_up_staging(self, transition):
    """Sets up staging ops for prefetching the next transition.

    This allows us to hide the py_func latency. To do so we use a staging area
    to pre-fetch the next batch of transitions.

    Args:
      transition: tuple of tf.Tensors with shape
        memory.get_transition_elements().

    Returns:
      prefetched_transition: tuple of tf.Tensors with shape
        memory.get_transition_elements() that have been previously prefetched.
    """
    transition_type = self.memory.get_transition_elements()

    # Create the staging area in CPU.
    prefetch_area = tf.contrib.staging.StagingArea(
        [shape_with_type.type for shape_with_type in transition_type])

    # Store prefetch op for tests, but keep it private -- users should not be
    # calling _prefetch_batch.
    self._prefetch_batch = prefetch_area.put(transition)
    initial_prefetch = tf.cond(
        tf.equal(prefetch_area.size(), 0),
        lambda: prefetch_area.put(transition), tf.no_op)

    # Every time a transition is sampled self.prefetch_batch will be
    # called. If the staging area is empty, two put ops will be called.
    with tf.control_dependencies([self._prefetch_batch, initial_prefetch]):
      prefetched_transition = prefetch_area.get()

    return prefetched_transition
Example #6
  def get_train_op(self, global_step):
    """Returns the operation that performs a training update."""
    # UPDATE_OPS picks up batch_norm updates.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = self.optimizer.minimize(self.losses['train'],
                                         global_step=global_step)
    return train_op
Example #7
    def transform(images, i_or_ij, is_flow, crop_height, crop_width,
                  shift_heights, shift_widths, resize):
        # Expect (i, j) for flows and masks and i for images.
        if isinstance(i_or_ij, int):
            i = i_or_ij
            # Flow needs i and j.
            assert not is_flow
        else:
            i, j = i_or_ij

        if is_flow:
            shifts = tf.stack([shift_heights, shift_widths], axis=-1)
            flow_offset = shifts[i] - shifts[j]
            images = images + tf.cast(flow_offset, tf.float32)

        shift_height = shift_heights[i]
        shift_width = shift_widths[i]
        height = images.shape[-3]
        width = images.shape[-2]

        # Assert that the cropped bounding box does not go out of the image frame.
        op1 = tf.compat.v1.assert_greater_equal(crop_height + shift_height, 0)
        op2 = tf.compat.v1.assert_greater_equal(crop_width + shift_width, 0)
        op3 = tf.compat.v1.assert_less_equal(
            height - crop_height + shift_height, height)
        op4 = tf.compat.v1.assert_less_equal(width - crop_width + shift_width,
                                             width)
        op5 = tf.compat.v1.assert_greater(
            height,
            2 * crop_height,
            message='Image height is too small for cropping.')
        op6 = tf.compat.v1.assert_greater(
            width,
            2 * crop_width,
            message='Image width is too small for cropping.')
        with tf.control_dependencies([op1, op2, op3, op4, op5, op6]):
            images = images[:, crop_height + shift_height:height -
                            crop_height + shift_height, crop_width +
                            shift_width:width - crop_width + shift_width, :]
        if resize:
            images = uflow_utils.resize(images, height, width, is_flow=is_flow)
            images.set_shape((images.shape[0], height, width, images.shape[3]))
        else:
            images.set_shape((images.shape[0], height - 2 * crop_height,
                              width - 2 * crop_width, images.shape[3]))
        return images
Example #8
  def transform(images, is_flow, crop_height, crop_width, resize):

    height = images.shape[-3]
    width = images.shape[-2]

    op5 = tf.compat.v1.assert_greater(
        height,
        2 * crop_height,
        message='Image height is too small for cropping.')
    op6 = tf.compat.v1.assert_greater(
        width, 2 * crop_width, message='Image width is too small for cropping.')
    with tf.control_dependencies([op5, op6]):
      images = images[:, crop_height:height - crop_height,
                      crop_width:width - crop_width, :]
    if resize:
      images = smurf_utils.resize(images, height, width, is_flow=is_flow)
      images.set_shape((images.shape[0], height, width, images.shape[3]))
    else:
      images.set_shape((images.shape[0], height - 2 * crop_height,
                        width - 2 * crop_width, images.shape[3]))
    return images
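Both transform variants above center-crop by slicing crop_height and crop_width pixels off each side, so the output spatial size is (height - 2*crop_height, width - 2*crop_width). A NumPy sketch of just that slicing (shapes only, ignoring flow handling and the assertions):

import numpy as np

images = np.zeros((2, 100, 120, 3))  # [image pair, height, width, channels]
crop_height, crop_width = 10, 20
cropped = images[:, crop_height:100 - crop_height, crop_width:120 - crop_width, :]
print(cropped.shape)  # (2, 80, 80, 3)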
Example #9
    def _body(step, *args):
      """The inner update loop body."""
      updated_embedding_vars = args[0:num_embedding_vars]
      updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                             num_fc_vars]
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['embeddings']

      updated_fc_weights, updated_fc_bias = updated_fc_vars
      support_logits = tf.matmul(support_embeddings,
                                 updated_fc_weights) + updated_fc_bias

      support_logits = support_logits[:, 0:data.way]
      loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                             support_logits)

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

      with tf.control_dependencies([print_op]):
        updated_embedding_vars = gradient_descent_step(
            loss, updated_embedding_vars, self.first_order,
            self.adapt_batch_norm, self.alpha, False)['updated_vars']
        updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                                self.first_order,
                                                self.adapt_batch_norm,
                                                self.alpha,
                                                False)['updated_vars']

        step = step + 1
      return tuple([step] + list(updated_embedding_vars) +
                   list(updated_fc_vars))
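gradient_descent_step is not shown in this snippet; as a rough sketch of what such a helper might do (a hypothetical signature, not the library's exact one), a first-order inner-loop update applies one SGD step to each variable and stops gradients through the step:

import tensorflow.compat.v1 as tf

def gradient_descent_step_sketch(loss, variables, alpha, first_order=True):
  """Hypothetical helper: one SGD step over `variables` with step size alpha."""
  grads = tf.gradients(loss, variables)
  if first_order:
    grads = [tf.stop_gradient(g) for g in grads]  # drop second-order terms
  updated = [v - alpha * g for v, g in zip(variables, grads)]
  return {'updated_vars': updated}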
Example #10
def linear_classifier(embeddings, num_classes, cosine_classifier,
                      cosine_logits_multiplier, use_weight_norm, weight_decay):
    """Forward pass through a linear classifier, or possibly a cosine classifier.

  Args:
    embeddings: A Tensor of size [batch size, embedding dim].
    num_classes: An integer; the dimension of the classification.
    cosine_classifier: A bool. If true, a cosine classifier is used, which does
      not require a bias.
    cosine_logits_multiplier: A float. Only used if cosine_classifier is True,
      and multiplies the resulting logits.
    use_weight_norm: A bool. Whether weight norm was used. If so, then if using
      cosine classifier, normalize only the embeddings but not the weights.
    weight_decay: A float; the scalar multiple on the L2 regularization of the
      weight matrix.

  Returns:
    logits: A Tensor of size [batch size, num outputs].
  """

    embedding_dims = embeddings.get_shape().as_list()[-1]

    if use_weight_norm:
        # A variable to keep track of whether the initialization has already
        # happened.
        data_dependent_init_done = tf.get_variable('data_dependent_init_done',
                                                   initializer=0,
                                                   dtype=tf.int32,
                                                   trainable=False)

        w_fc = tf.get_variable('w_fc', [embedding_dims, num_classes],
                               initializer=tf.random_normal_initializer(
                                   0, 0.05),
                               trainable=True)
        # This init is temporary as it needs to be done in a data-dependent way.
        # It will be overwritten during the first forward pass through this layer.
        g = tf.get_variable('g',
                            dtype=tf.float32,
                            initializer=tf.ones([num_classes]),
                            trainable=True)
        b_fc = None
        if not cosine_classifier:
            # Also initialize a bias.
            b_fc = tf.get_variable('b_fc',
                                   initializer=tf.zeros([num_classes]),
                                   trainable=True)

        def _do_data_dependent_init():
            """Returns ops for the data-dependent init of g and maybe b_fc."""
            w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
            output_init = tf.matmul(embeddings, w_fc_normalized)
            mean_init, var_init = tf.nn.moments(output_init, [0])
            # Data-dependent init values.
            g_init_value = 1. / tf.sqrt(var_init + 1e-10)
            ops = [tf.assign(g, g_init_value)]
            if not cosine_classifier:
                # Also initialize a bias in a data-dependent way.
                b_fc_init_value = -mean_init * g_init_value
                ops.append(tf.assign(b_fc, b_fc_init_value))
            # Mark that the data-dependent initialization is done to prevent it from
            # happening again in the future.
            ops.append(tf.assign(data_dependent_init_done, 1))
            return tf.group(*ops)

        # Possibly perform data-dependent init (if it hasn't been done already).
        init_op = tf.cond(tf.equal(data_dependent_init_done, 0),
                          _do_data_dependent_init, tf.no_op)

        with tf.control_dependencies([init_op]):
            # Apply weight normalization.
            w_fc *= g / tf.sqrt(tf.reduce_sum(tf.square(w_fc), [0]))
            # Forward pass through the layer defined by w_fc and b_fc.
            logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                                    cosine_classifier,
                                                    cosine_logits_multiplier,
                                                    True)

    else:
        # No weight norm.
        w_fc = functional_backbones.weight_variable(
            [embedding_dims, num_classes], weight_decay=weight_decay)
        b_fc = None
        if not cosine_classifier:
            # Also initialize a bias.
            b_fc = functional_backbones.bias_variable([num_classes])
        # Forward pass through the layer defined by w_fc and b_fc.
        logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                                cosine_classifier,
                                                cosine_logits_multiplier,
                                                False)
    return logits
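linear_classifier_forward_pass is defined elsewhere; as a hedged sketch of the cosine-classifier branch it implements (an assumption based on the docstring, not the library's exact code), the logits are a scaled dot product of L2-normalized embeddings and weights, with no bias:

import tensorflow.compat.v1 as tf

def cosine_logits_sketch(embeddings, w_fc, cosine_logits_multiplier):
  """Illustrative cosine-classifier logits; w_fc has shape [embedding_dim, num_classes]."""
  embeddings = tf.nn.l2_normalize(embeddings, axis=-1)
  w_fc = tf.nn.l2_normalize(w_fc, axis=0)
  return cosine_logits_multiplier * tf.matmul(embeddings, w_fc)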
Example #11
def project_distribution(supports, weights, target_support,
                         validate_args=False):
  """Projects a batch of (support, weights) onto target_support.

  Based on equation (7) in (Bellemare et al., 2017):
    https://arxiv.org/abs/1707.06887
  In the rest of the comments we will refer to this equation simply as Eq7.

  This code is not easy to digest, so we will use a running example to clarify
  what is going on, with the following sample inputs:
    * supports =       [[0, 2, 4, 6, 8],
                        [1, 3, 4, 5, 6]]
    * weights =        [[0.1, 0.6, 0.1, 0.1, 0.1],
                        [0.1, 0.2, 0.5, 0.1, 0.1]]
    * target_support = [4, 5, 6, 7, 8]
  In the code below, comments preceded with 'Ex:' will be referencing the above
  values.

  Args:
    supports: Tensor of shape (batch_size, num_dims) defining supports for the
      distribution.
    weights: Tensor of shape (batch_size, num_dims) defining weights on the
      original support points. Although for the CategoricalDQN agent these
      weights are probabilities, it is not required that they are.
    target_support: Tensor of shape (num_dims) defining support of the projected
      distribution. The values must be monotonically increasing. Vmin and Vmax
      will be inferred from the first and last elements of this tensor,
      respectively. The values in this tensor must be equally spaced.
    validate_args: Whether we will verify the contents of the
      target_support parameter.

  Returns:
    A Tensor of shape (batch_size, num_dims) with the projection of a batch of
    (support, weights) onto target_support.

  Raises:
    ValueError: If target_support has no dimensions, or if shapes of supports,
      weights, and target_support are incompatible.
  """
  target_support_deltas = target_support[1:] - target_support[:-1]
  # delta_z = `\Delta z` in Eq7.
  delta_z = target_support_deltas[0]
  validate_deps = []
  supports.shape.assert_is_compatible_with(weights.shape)
  supports[0].shape.assert_is_compatible_with(target_support.shape)
  target_support.shape.assert_has_rank(1)
  if validate_args:
    # Assert that supports and weights have the same shapes.
    validate_deps.append(
        tf.Assert(
            tf.reduce_all(tf.equal(tf.shape(supports), tf.shape(weights))),
            [supports, weights]))
    # Assert that elements of supports and target_support have the same shape.
    validate_deps.append(
        tf.Assert(
            tf.reduce_all(
                tf.equal(tf.shape(supports)[1], tf.shape(target_support))),
            [supports, target_support]))
    # Assert that target_support has a single dimension.
    validate_deps.append(
        tf.Assert(
            tf.equal(tf.size(tf.shape(target_support)), 1), [target_support]))
    # Assert that the target_support is monotonically increasing.
    validate_deps.append(
        tf.Assert(tf.reduce_all(target_support_deltas > 0), [target_support]))
    # Assert that the values in target_support are equally spaced.
    validate_deps.append(
        tf.Assert(
            tf.reduce_all(tf.equal(target_support_deltas, delta_z)),
            [target_support]))

  with tf.control_dependencies(validate_deps):
    # Ex: `v_min, v_max = 4, 8`.
    v_min, v_max = target_support[0], target_support[-1]
    # Ex: `batch_size = 2`.
    batch_size = tf.shape(supports)[0]
    # `N` in Eq7.
    # Ex: `num_dims = 5`.
    num_dims = tf.shape(target_support)[0]
    # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7.
    # Ex: `clipped_support = [[[ 4.  4.  4.  6.  8.]]
    #                         [[ 4.  4.  4.  5.  6.]]]`.
    clipped_support = tf.clip_by_value(supports, v_min, v_max)[:, None, :]
    # Ex: `tiled_support = [[[[ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]]
    #                        [[ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]]]]`.
    tiled_support = tf.tile([clipped_support], [1, 1, num_dims, 1])
    # Ex: `reshaped_target_support = [[[ 4.]
    #                                  [ 5.]
    #                                  [ 6.]
    #                                  [ 7.]
    #                                  [ 8.]]
    #                                 [[ 4.]
    #                                  [ 5.]
    #                                  [ 6.]
    #                                  [ 7.]
    #                                  [ 8.]]]`.
    reshaped_target_support = tf.tile(target_support[:, None], [batch_size, 1])
    reshaped_target_support = tf.reshape(reshaped_target_support,
                                         [batch_size, num_dims, 1])
    # numerator = `|clipped_support - z_i|` in Eq7.
    # Ex: `numerator = [[[[ 0.  0.  0.  2.  4.]
    #                     [ 1.  1.  1.  1.  3.]
    #                     [ 2.  2.  2.  0.  2.]
    #                     [ 3.  3.  3.  1.  1.]
    #                     [ 4.  4.  4.  2.  0.]]
    #                    [[ 0.  0.  0.  1.  2.]
    #                     [ 1.  1.  1.  0.  1.]
    #                     [ 2.  2.  2.  1.  0.]
    #                     [ 3.  3.  3.  2.  1.]
    #                     [ 4.  4.  4.  3.  2.]]]]`.
    numerator = tf.abs(tiled_support - reshaped_target_support)
    quotient = 1 - (numerator / delta_z)
    # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7.
    # Ex: `clipped_quotient = [[[[ 1.  1.  1.  0.  0.]
    #                            [ 0.  0.  0.  0.  0.]
    #                            [ 0.  0.  0.  1.  0.]
    #                            [ 0.  0.  0.  0.  0.]
    #                            [ 0.  0.  0.  0.  1.]]
    #                           [[ 1.  1.  1.  0.  0.]
    #                            [ 0.  0.  0.  1.  0.]
    #                            [ 0.  0.  0.  0.  1.]
    #                            [ 0.  0.  0.  0.  0.]
    #                            [ 0.  0.  0.  0.  0.]]]]`.
    clipped_quotient = tf.clip_by_value(quotient, 0, 1)
    # Ex: `weights = [[ 0.1  0.6  0.1  0.1  0.1]
    #                 [ 0.1  0.2  0.5  0.1  0.1]]`.
    weights = weights[:, None, :]
    # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))`
    # in Eq7.
    # Ex: `inner_prod = [[[[ 0.1  0.6  0.1  0.  0. ]
    #                      [ 0.   0.   0.   0.  0. ]
    #                      [ 0.   0.   0.   0.1 0. ]
    #                      [ 0.   0.   0.   0.  0. ]
    #                      [ 0.   0.   0.   0.  0.1]]
    #                     [[ 0.1  0.2  0.5  0.  0. ]
    #                      [ 0.   0.   0.   0.1 0. ]
    #                      [ 0.   0.   0.   0.  0.1]
    #                      [ 0.   0.   0.   0.  0. ]
    #                      [ 0.   0.   0.   0.  0. ]]]]`.
    inner_prod = clipped_quotient * weights
    # Ex: `projection = [[ 0.8 0.0 0.1 0.0 0.1]
    #                    [ 0.8 0.1 0.1 0.0 0.0]]`.
    projection = tf.reduce_sum(inner_prod, 3)
    projection = tf.reshape(projection, [batch_size, num_dims])
    return projection
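A usage sketch with the running example from the docstring, assuming the project_distribution function above is in scope and a TF1 (tf.compat.v1) session is used; the output should match the projection shown in the final comment:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
supports = tf.constant([[0., 2., 4., 6., 8.], [1., 3., 4., 5., 6.]])
weights = tf.constant([[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]])
target_support = tf.constant([4., 5., 6., 7., 8.])
projection = project_distribution(supports, weights, target_support,
                                  validate_args=True)
with tf.Session() as sess:
  print(sess.run(projection))  # ~[[0.8 0.  0.1 0.  0.1], [0.8 0.1 0.1 0.  0. ]]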
Example #12
def get_train_op(loss,
                 learning_rate=0.001,
                 lr_decay_steps=10000,
                 lr_decay_rate=0.98,
                 gradient_clip_norm=3.0,
                 use_tpu=True,
                 variables=None):
    """Get training operation with gradient clipping and learning rate decay.

  Distilled from tf.contrib.layers.optimize_loss().

  Args:
    loss: Scalar tensor of the loss function.
    learning_rate: Scalar initial learning rate.
    lr_decay_steps: Exponential decay timescale.
    lr_decay_rate: Exponential decay magnitude.
    gradient_clip_norm: Global norm by which to scale gradients.
    use_tpu: Use tpu for training.
    variables: List of variables to optimize. tf.trainable_variables() if None.

  Returns:
    train_op: Operation that runs one iteration of training.
  """
    global_step = tf.train.get_or_create_global_step()

    with tf.variable_scope('training', values=[loss, global_step]):
        # Make sure update ops run before computing loss.
        update_ops = list(set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
        with tf.control_dependencies(update_ops):
            loss = tf.identity(loss)

        # Learning rate variable, with decay.
        learning_rate_decay_fn = functools.partial(tf.train.exponential_decay,
                                                   decay_steps=lr_decay_steps,
                                                   decay_rate=lr_decay_rate,
                                                   staircase=True)
        lr = tf.get_variable(
            'learning_rate', [],
            trainable=False,
            initializer=tf.constant_initializer(learning_rate))
        lr = learning_rate_decay_fn(lr, global_step)

        # Optimizer.
        opt = tf.train.AdamOptimizer(lr)
        if use_tpu:
            opt = tf.tpu.CrossShardOptimizer(opt)

        # All trainable variables, if specific variables are not specified.
        if variables is None:
            variables = tf.trainable_variables()

        # Compute gradients.
        gradients = opt.compute_gradients(loss,
                                          variables,
                                          colocate_gradients_with_ops=False)

        # Optionally clip gradients by global norm.
        if isinstance(gradient_clip_norm, float):
            gradients = _clip_gradients_by_norm(gradients, gradient_clip_norm)

        # Create gradient updates.
        grad_updates = opt.apply_gradients(gradients,
                                           global_step=global_step,
                                           name='train')

        # Ensure the train_op computes grad_updates.
        with tf.control_dependencies([grad_updates]):
            train_op = tf.identity(loss)

        return train_op
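A minimal non-TPU usage sketch, assuming TF1 (tf.compat.v1) with get_train_op and its _clip_gradients_by_norm helper in scope; the returned op runs one Adam step with exponential learning-rate decay:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
x = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable('w', [4, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
train_op = get_train_op(loss, learning_rate=1e-3, use_tpu=False)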
Example #13
    def detailed_forward_pass(self, data):
        """Returns all information from a forward pass of the `OptimizationLearner`.

    Args:
      data: A `meta_dataset.providers.Episode` containing the data for the
        episode.

    Returns:
      A `collections.NamedTuple` that contains the results of the forward pass.
    """
        # Loop initialization.
        init_loop_variables = self.task_parameters
        init_loop_variable_refs = [
            v.experimental_ref() for v in init_loop_variables
        ]

        # Construct ops for data-dependent episodic initialization.
        episodic_init_ops = self.episodic_init_ops(
            labels=data.support_labels,
            embeddings=self.embedding_fn(data.support_images, training=True),
            task_parameters=init_loop_variables,
        )

        def _forward_pass(iteration_idx_, variables_mapping_, images_,
                          onehot_labels_):
            """Helper function to compute the outputs of a forward pass."""

            with self.embedding_fn.reparameterize(variables_mapping_):
                # TODO(eringrant): Implement non-transductive batch normalization (i.e.,
                # pass the support set statistics through the query set forward pass).
                embeddings_ = self.embedding_fn(images_, training=True)

            # TODO(eringrant): `head_fn` is an attribute of the subclass.
            with self.head_fn.reparameterize(variables_mapping_):
                predictions_ = self.head_fn(embeddings_)[:, :data.way]

            accuracy_ = tf.reduce_mean(input_tensor=self.compute_accuracy(
                onehot_labels=onehot_labels_, predictions=predictions_))

            inner_objective_ = self.inner_objective(
                onehot_labels=onehot_labels_,
                predictions=predictions_,
                iteration_idx=iteration_idx_)

            outer_objective_ = self.outer_objective(
                onehot_labels=onehot_labels_,
                predictions=predictions_,
            )

            return ForwardPass(
                embeddings=embeddings_,
                predictions=predictions_,
                inner_objective_value=inner_objective_,
                outer_objective_value=outer_objective_,
                accuracy=accuracy_,
            )

        def _objective_fn(loop_variables_, iteration_idx_):
            """Evaluate the support set objective given `loop_variables_`."""

            # Get attribute paths for the loop_variables.
            loop_variables_mapping_ = dict(
                zip(init_loop_variable_refs, loop_variables_))

            adaptation_support_results = _forward_pass(
                iteration_idx_=iteration_idx_,
                variables_mapping_=loop_variables_mapping_,
                images_=data.support_images,
                onehot_labels_=data.onehot_support_labels)

            return adaptation_support_results.inner_objective_value

        def _e_step(loop_variables_):
            """Evaluate expectations given `loop_variables_`."""

            # Get attribute paths for the loop_variables.
            loop_variables_dict_ = dict(
                zip(init_loop_variable_refs, loop_variables_))

            with self.embedding_fn.reparameterize(loop_variables_dict_):
                # TODO(eringrant): training to True for normalization with batch stats.
                # Figure out the appropriate way to pass this around.
                train_embeddings_ = self.embedding_fn(data.train_images,
                                                      training=True)

            class_embeddings_ = learner_base.class_specific_data(
                data.onehot_train_labels, train_embeddings_, self.logit_dim)

            def _compute_responsibilities(examples_, class_idx):
                train_predictions_ = tf.squeeze(self.head_fn(
                    embeddings=examples_,
                    components=True,
                    class_idx=[class_idx]),
                                                axis=1)
                return tf.nn.softmax(train_predictions_, axis=-1)

            with self.head_fn.reparameterize(loop_variables_dict_):
                class_responsibilities_ = [
                    _compute_responsibilities(embeddings_, class_idx=i)
                    for i, embeddings_ in enumerate(class_embeddings_)
                ]

            return class_embeddings_, class_responsibilities_

        def _m_step(preupdate_vars, all_embeddings_, all_responsibilities_):
            """Compute parameter estimates given `loop_variables_`."""

            means, log_scales, logits = zip(
                *map(reparameterizable_distributions.fit_gaussian_mixture,
                     all_embeddings_, all_responsibilities_,
                     itertools.repeat(self.head_fn.damping)))

            def flatten(x):
                return list(itertools.chain.from_iterable(x))

            means = flatten(means)
            log_scales = flatten(log_scales)
            logits = flatten(logits)

            if not self.head_fn.estimate_loc:
                means = [None for _ in means]

            if not self.head_fn.estimate_scale:
                log_scales = [None for _ in log_scales]

            if not self.head_fn.estimate_logits:
                logits = [None for _ in logits]

            updated_vars = means + log_scales + logits

            # Replace constant variables.
            # TODO(eringrant): This interface differs from just excluding these
            # variables from `task_variables`.
            no_none_updated_vars = []
            for preupdate_var, updated_var in zip(preupdate_vars,
                                                  updated_vars):
                if updated_var is None:
                    no_none_updated_vars.append(preupdate_var)
                else:
                    no_none_updated_vars.append(updated_var)

            # TODO(eringrant): This assumes an ordering of mean, log_scales,
            # mixing_logits.
            return no_none_updated_vars

        # Loop body.
        with tf.control_dependencies(episodic_init_ops):

            final_loop_variables = init_loop_variables

            # Inner loop of expectation maximization.
            num_em_steps = getattr(self, 'num_em_steps', 0)
            if num_em_steps > 0:
                final_loop_variables = em_loop(num_updates=num_em_steps,
                                               e_step=_e_step,
                                               m_step=_m_step,
                                               variables=final_loop_variables)

            # Inner loop of gradient-based optimization.
            num_optimizer_steps = (self.num_update_steps +
                                   (self.additional_evaluation_update_steps
                                    if not self.is_training else 0))
            if num_optimizer_steps > 0:
                # pylint: disable=no-value-for-parameter
                final_loop_variables = optimizer_loop(
                    num_updates=num_optimizer_steps,
                    objective_fn=_objective_fn,
                    update_fn=self.update_fn,
                    variables=final_loop_variables,
                    first_order=self.first_order,
                    clip_grad_norm=self.clip_grad_norm,
                )
                # pylint: enable=no-value-for-parameter

            # If no inner-loop adaptation is performed, ensure the episodic
            # initialization is still part of the graph via a control dependency.
            if num_optimizer_steps + num_em_steps == 0:
                final_loop_variables = [tf.identity(v) for v in init_loop_variables]

        # Get variable references to use when remapping the loop_variables.
        init_loop_variables_mapping = dict(
            zip(init_loop_variable_refs, init_loop_variables))
        final_loop_variables_mapping = dict(
            zip(init_loop_variable_refs, final_loop_variables))

        # Collect statistics about the inner optimization.
        with tf.compat.v1.name_scope('pre-adaptation'):
            with tf.compat.v1.name_scope('support'):
                pre_adaptation_support_results = _forward_pass(
                    iteration_idx_=0,
                    variables_mapping_=init_loop_variables_mapping,
                    images_=data.support_images,
                    onehot_labels_=data.onehot_support_labels)

            with tf.compat.v1.name_scope('query'):
                pre_adaptation_query_results = _forward_pass(
                    iteration_idx_=0,
                    variables_mapping_=init_loop_variables_mapping,
                    images_=data.query_images,
                    onehot_labels_=data.onehot_query_labels)

        with tf.compat.v1.name_scope('post-adaptation'):
            with tf.compat.v1.name_scope('support'):
                post_adaptation_support_results = _forward_pass(
                    iteration_idx_=num_optimizer_steps,
                    variables_mapping_=final_loop_variables_mapping,
                    images_=data.support_images,
                    onehot_labels_=data.onehot_support_labels,
                )

            with tf.compat.v1.name_scope('query'):
                post_adaptation_query_results = _forward_pass(
                    iteration_idx_=num_optimizer_steps,
                    variables_mapping_=final_loop_variables_mapping,
                    images_=data.query_images,
                    onehot_labels_=data.onehot_query_labels,
                )

        def _support_module_objective_fn(module_variables_,
                                         module_variable_refs_):
            """Evaluate the query set objective given `module_variables_`."""
            # Use the values of the parameters at convergence as the default value.
            variables_mapping_ = final_loop_variables_mapping.copy()

            # Loop over and replace the module-specific variables.
            for module_variable_ref, module_variable in zip(
                    module_variable_refs_, module_variables_):
                variables_mapping_[module_variable_ref] = module_variable

            adaptation_query_results = _forward_pass(
                iteration_idx_=num_optimizer_steps,
                variables_mapping_=variables_mapping_,
                images_=data.support_images,
                onehot_labels_=data.onehot_support_labels,
            )

            return adaptation_query_results.inner_objective_value

        def _query_module_objective_fn(module_variables_,
                                       module_variable_refs_):
            """Evaluate the query set objective given `module_variables_`."""
            # Use the values of the parameters at convergence as the default value.
            variables_mapping_ = final_loop_variables_mapping.copy()

            # Loop over and replace the module-specific variables.
            for module_variable_ref, module_variable in zip(
                    module_variable_refs_, module_variables_):
                variables_mapping_[module_variable_ref] = module_variable

            adaptation_query_results = _forward_pass(
                iteration_idx_=num_optimizer_steps,
                variables_mapping_=variables_mapping_,
                images_=data.query_images,
                onehot_labels_=data.onehot_query_labels)

            return adaptation_query_results.inner_objective_value

        return Adaptation(
            pre_adaptation_support_results=pre_adaptation_support_results,
            post_adaptation_support_results=post_adaptation_support_results,
            pre_adaptation_query_results=pre_adaptation_query_results,
            post_adaptation_query_results=post_adaptation_query_results,
            objective_fn=_objective_fn,
            support_module_objective_fn=_support_module_objective_fn,
            query_module_objective_fn=_query_module_objective_fn,
            forward_pass_fn=_forward_pass,
            init_loop_variables_mapping=init_loop_variables_mapping,
            final_loop_variables_mapping=final_loop_variables_mapping,
        )
Example #14
    def forward_pass(self, data):
        """Computes the query logits for the given episode `data`."""

        if self.film_init == 'scratch':
            self.film_selector = None
        elif self.film_init == 'imagenet':
            # Note: this makes the assumption that the first set of learned FiLM
            # parameters corresponds to the ImageNet dataset. Otherwise, the
            # following line should be changed appropriately.
            self.film_selector = 0
        elif self.film_init in ['blender', 'blender_hard']:
            dataset_logits = functional_backbones.dataset_classifier(
                data.support_images)
            if self.film_init == 'blender_hard':
                # Select only the argmax entry.
                self.film_selector = tf.one_hot(
                    tf.math.argmax(dataset_logits, axis=-1),
                    depth=tf.shape(dataset_logits)[1])
            else:
                # Take a convex combination.
                self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

        if self.num_steps:
            # Initial forward pass, required for the `unused_op` below and for placing
            # variables in tf.trainable_variables() for the block below to pick up.
            loss = self._compute_losses(data, compute_on_query=False)['loss']

            # Pick out the variables to optimize.
            self.opt_vars = []
            for var in tf.trainable_variables():
                if '_for_film_learner' in var.name:
                    self.opt_vars.append(var)
            tf.logging.info('FiLMLearner will optimize vars: {}'.format(
                self.opt_vars))

        for i in range(self.num_steps):
            if i == 0:
                # Re-initialize the variables to optimize for the new episode, to ensure
                # the FiLM parameters aren't re-used across tasks of a given dataset.
                vars_reset = tf.variables_initializer(var_list=self.opt_vars)
                # Adam-related variables are created when minimize() is called.
                # We create an unused op here to put all Adam variables under
                # the 'adam_opt' name scope and create a reset op to reinitialize
                # these variables before the first finetune step.
                with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                    unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
                adam_reset = tf.variables_initializer(self.opt.variables())

                with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                             self.opt_vars):
                    print_op = tf.no_op()
                    if self.debug_log:
                        print_op = tf.print([
                            'step: %d' % i, self.opt_vars[0][0], 'loss:', loss
                        ],
                                            summarize=-1)

                    with tf.control_dependencies([print_op]):
                        # Get the train op.
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

            else:
                with tf.control_dependencies([train_op, loss, acc] +
                                             self.opt_vars +
                                             [query_loss, query_acc] *
                                             int(self.debug_log)):

                    print_op = tf.no_op()
                    if self.debug_log:
                        print_list = [
                            '################',
                            'step: %d' % i,
                            self.opt_vars[0][0],
                            'support loss:',
                            loss,
                            'query loss:',
                            query_loss,
                            'support acc:',
                            acc,
                            'query acc:',
                            query_acc,
                        ]
                        print_op = tf.print(print_list)

                    with tf.control_dependencies([print_op]):
                        # Get the train op (the loss is returned just for printing).
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

        # Training is now over, compute the final query logits.
        dependency_list = [] if not self.num_steps else [train_op] + self.opt_vars
        with tf.control_dependencies(dependency_list):
            results = self._compute_losses(data, compute_on_query=True)
            (loss, query_loss, query_logits, acc,
             query_acc) = (results['loss'], results['query_loss'],
                           results['query_logits'], results['acc'],
                           results['query_acc'])

            print_op = tf.no_op()
            if self.debug_log:
                print_op = tf.print([
                    'Done training',
                    'support loss:',
                    loss,
                    'query loss:',
                    query_loss,
                    'support acc:',
                    acc,
                    'query acc:',
                    query_acc,
                ])
            with tf.control_dependencies([print_op]):
                query_logits = tf.identity(query_logits)

        return query_logits
Example #15
  def get_updated_global_step(self):
    with tf.control_dependencies([self.train_op]):
      global_step = tf.identity(tf.train.get_global_step())
    return global_step
Example #16
def geometric_augmentation(images,
                           flow = None,
                           mask = None,
                           crop_height = 640,
                           crop_width = 640,
                           probability_flip_left_right = 0.5,
                           probability_flip_up_down = 0.1,
                           probability_scale = 0.8,
                           probability_relative_scale = 0.,
                           probability_stretch = 0.8,
                           probability_rotation = 0.0,
                           probability_relative_rotation = 0.0,
                           probability_crop_offset = 0.0,
                           min_bound_scale = -0.2,
                           max_bound_scale = 0.6,
                           max_strech_scale = 0.2,
                           min_bound_relative_scale = -0.1,
                           max_bound_relative_scale = 0.1,
                           max_rotation_deg = 15,
                           max_relative_rotation_deg = 3,
                           max_relative_crop_offset = 5,
                           return_full_scale=False):

  """Applies geometric augmentations to an image pair and corresponding flow.

  Args:
    images: Image pair of shape [2, height, width, channels].
    flow: Corresponding forward flow field of shape [height, width, 2].
    mask: Mask indicating which positions in the flow field hold valid flow
      vectors of shape [height, width, 1]. Non-valid positions are encoded with
      0, valid positions with 1.
    crop_height: Height of the final augmented output.
    crop_width: Width of the final augmented output.
    probability_flip_left_right: Probability of applying a left/right flip.
    probability_flip_up_down: Probability of applying an up/down flip.
    probability_scale: Probability of applying scale augmentation.
    probability_relative_scale: Probability of applying scale augmentation to
      only the second frame of the image pair.
    probability_stretch: Probability of applying stretch augmentation (scale
      without keeping the aspect ratio).
    probability_rotation: Probability of applying rotation augmentation.
    probability_relative_rotation: Probability of applying rotation augmentation
      to only the second frame of the image pair.
    probability_crop_offset: Probability of applying a relative offset while
      cropping.
    min_bound_scale: Defines the smallest possible scaling factor as
      2**min_bound_scale.
    max_bound_scale: Defines the largest possible scaling factor as
      2**max_bound_scale.
    max_strech_scale: Defines the smallest and largest possible stretching
      factor as 2**-max_strech_scale and 2**max_strech_scale.
    min_bound_relative_scale: Defines the smallest possible scaling factor for
      the relative scaling as 2**min_bound_relative_scale.
    max_bound_relative_scale: Defines the largest possible scaling factor for
      the relative scaling as 2**max_bound_relative_scale.
    max_rotation_deg: Defines the maximum angle of rotation in degrees.
    max_relative_rotation_deg: Defines the maximum angle of rotation in degrees
      for the relative rotation.
    max_relative_crop_offset: Defines the maximum relative offset in pixels for
      cropping.
    return_full_scale: bool. If this is passed, the full size images will be
      returned in addition to the geometrically augmented (cropped and / or
      resized) images. In addition to the resized images, the crop height,
      width, and any padding applied will be returned.

  Returns:
    if return_full_scale is False:
      Augmented images, flow and mask (if not None).
    if return_full_scale is True:
      Augmented images, flow, mask, full_size_images, crop_h, crop_w, pad_h,
       and pad_w.
  """

  # apply geometric augmentation
  if probability_flip_left_right > 0:
    images, flow, mask = random_flip_left_right(
        images, flow, mask, probability_flip_left_right)

  if probability_flip_up_down > 0:
    images, flow, mask = random_flip_up_down(
        images, flow, mask, probability_flip_up_down)

  if probability_scale > 0 or probability_stretch > 0:
    images, flow, mask = random_scale(
        images,
        flow,
        mask,
        min_scale=min_bound_scale,
        max_scale=max_bound_scale,
        max_strech=max_strech_scale,
        probability_scale=probability_scale,
        probability_strech=probability_stretch)

  if probability_relative_scale > 0:
    images, flow, mask = random_scale_second(
        images, flow, mask,
        min_scale=min_bound_relative_scale,
        max_scale=max_bound_relative_scale,
        probability_scale=probability_relative_scale)

  if probability_rotation > 0:
    images, flow, mask = random_rotation(
        images, flow, mask,
        probability=probability_rotation,
        max_rotation=max_rotation_deg, not_empty_crop=True)

  if probability_relative_rotation > 0:
    images, flow, mask = random_rotation_second(
        images, flow, mask,
        probability=probability_relative_rotation,
        max_rotation=max_relative_rotation_deg, not_empty_crop=True)

  images_uncropped = images
  images, flow, mask, offset_h, offset_w = random_crop(
      images, flow, mask, crop_height, crop_width,
      relative_offset=max_relative_crop_offset,
      probability_crop_offset=probability_crop_offset)
  # Add 100 / 200 pixels to crop height / width for full scale warp
  pad_to_size_h = crop_height + 200
  pad_to_size_w = crop_width + 400
  if return_full_scale:
    if pad_to_size_w:
      uncropped_shape = tf.shape(images_uncropped)
      if images.shape[1] > uncropped_shape[1] or images.shape[
          2] > uncropped_shape[2]:
        images_uncropped = images
        uncropped_shape = tf.shape(images_uncropped)
        offset_h = tf.zeros_like(offset_h)
        offset_w = tf.zeros_like(offset_w)

      if uncropped_shape[1] > pad_to_size_h:
        crop_ht = offset_h - (200 // 2)
        crop_hb = offset_h + crop_height + (200 // 2)
        crop_hb += tf.maximum(0, -crop_ht)
        crop_ht -= tf.maximum(0, -(uncropped_shape[1] - crop_hb))
        crop_ht = tf.maximum(crop_ht, 0)
        crop_hb = tf.minimum(crop_hb, uncropped_shape[1])
        offset_h -= crop_ht
        images_uncropped = images_uncropped[:, crop_ht:crop_hb, :, :]

      if uncropped_shape[2] > pad_to_size_w:
        crop_wt = offset_w - (400 // 2)
        crop_wb = offset_w + crop_width + (400 // 2)
        crop_wb += tf.maximum(0, -crop_wt)
        crop_wt -= tf.maximum(0, -(uncropped_shape[2] - crop_wb))
        crop_wt = tf.maximum(crop_wt, 0)
        crop_wb = tf.minimum(crop_wb, uncropped_shape[2])
        offset_w -= crop_wt
        images_uncropped = images_uncropped[:, :, crop_wt:crop_wb, :]

      uncropped_shape = tf.shape(images_uncropped)
      # Pad the remaining images up to (pad_to_size_h, pad_to_size_w).
      pad_h = pad_to_size_h - uncropped_shape[1]
      pad_w = pad_to_size_w - uncropped_shape[2]
      with tf.control_dependencies([
          tf.compat.v1.assert_greater_equal(pad_h, 0),
          tf.compat.v1.assert_greater_equal(pad_w, 0)
      ]):
        images_uncropped = tf.pad(images_uncropped,
                                  [[0, 0], [pad_h, 0], [pad_w, 0], [0, 0]])
      images_uncropped = tf.ensure_shape(images_uncropped,
                                         [2, pad_to_size_h, pad_to_size_w, 3])
    return images, flow, mask, images_uncropped, offset_h, offset_w, pad_h, pad_w

  return images, flow, mask
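The flip/scale/rotation helpers used above (random_flip_left_right and friends) are defined elsewhere; as a hedged illustration of the idea behind probability_flip_left_right only (not the library's implementation), a consistent left/right flip mirrors the image pair, flow, and mask along the width axis and negates the horizontal flow component (assumed here to be the last flow channel):

import tensorflow as tf

def flip_left_right_sketch(images, flow, mask, probability):
  """Illustrative only: mirror [2, H, W, C] images with [H, W, 2] flow and [H, W, 1] mask."""
  def _flipped():
    flipped_images = tf.reverse(images, axis=[-2])            # flip the width axis
    flipped_flow = tf.reverse(flow, axis=[-2]) * [1.0, -1.0]   # negate horizontal flow
    flipped_mask = tf.reverse(mask, axis=[-2])
    return flipped_images, flipped_flow, flipped_mask
  do_flip = tf.random.uniform([]) < probability
  return tf.cond(do_flip, _flipped, lambda: (images, flow, mask))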
Example #17
  def compute_logits(self, data):
    """Computes the class logits for the episode.

    Args:
      data: A `meta_dataset.providers.Episode`.

    Returns:
      The query set logits as a [num_query_images, way] matrix.

    Raises:
      ValueError: Distance must be one of l2 or cosine.
    """
    # ------------------------ Finetuning -------------------------------
    # Possibly make copies of embedding variables, if they will get modified.
    # This is for making temporary-only updates to the embedding network
    # which will not persist after the end of the episode.
    make_copies = self.finetune_all_layers

    # TODO(eringrant): Reduce the number of times the embedding function graph
    # is built with the same input.
    support_embeddings_params_moments = self.embedding_fn(
        data.support_images, self.is_training)
    support_embeddings = support_embeddings_params_moments['embeddings']
    support_embeddings_var_dict = support_embeddings_params_moments['params']

    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         support_embeddings_var_dict, make_copies)
    embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

    # Compute the initial training loss (only for printing purposes). This
    # line is also needed for adding the fc variables to the graph so that the
    # tf.all_variables() line below detects them.
    logits = self._fc_layer(support_embeddings)[:, 0:data.way]
    finetune_loss = self.compute_loss(
        onehot_labels=data.onehot_support_labels,
        predictions=logits,
    )

    # Decide which variables to finetune.
    fc_vars, vars_to_finetune = [], []
    for var in tf.trainable_variables():
      if 'fc_finetune' in var.name:
        fc_vars.append(var)
        vars_to_finetune.append(var)
    if self.finetune_all_layers:
      vars_to_finetune.extend(embedding_vars)
    logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

    for i in range(self.num_finetune_steps):
      if i == 0:
        # Randomly initialize the fc layer.
        fc_reset = tf.variables_initializer(var_list=fc_vars)
        # Adam-related variables are created when minimize() is called.
        # We create an unused op here to put all Adam variables under
        # the 'adam_opt' name scope and create a reset op to reinitialize
        # these variables before the first finetune step.
        adam_reset = tf.no_op()
        if self.finetune_with_adam:
          with tf.variable_scope('adam_opt'):
            unused_op = self.finetune_opt.minimize(
                finetune_loss, var_list=vars_to_finetune)
          adam_reset = tf.variables_initializer(self.finetune_opt.variables())
        with tf.control_dependencies(
            [fc_reset, adam_reset, finetune_loss, embedding_vars_copy_op] +
            vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
                finetune_loss
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

      else:
        with tf.control_dependencies([finetune_op, finetune_loss] +
                                     vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i,
                vars_to_finetune[0][0, 0],
                'loss:',
                finetune_loss,
                'accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_support_labels, predictions=logits),
                'query accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_query_labels, predictions=query_logits),
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

    # Finetuning is now over, compute the query performance using the updated
    # fc layer, and possibly the updated embedding network.
    with tf.control_dependencies([finetune_op] + vars_to_finetune):
      query_embeddings = self.embedding_fn(
          data.query_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']
      query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

      if self.debug_log:
        # The train logits are computed only for printing.
        support_embeddings = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, embedding_vars)),
            reuse=True)['embeddings']
        logits = self._fc_layer(support_embeddings)[:, 0:data.way]

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print([
            'accuracy:',
            self.compute_accuracy(
                labels=data.onehot_support_labels, predictions=logits),
            'query accuracy:',
            self.compute_accuracy(
                labels=data.onehot_query_labels, predictions=query_logits),
        ])
      with tf.control_dependencies([print_op]):
        query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

    return query_logits
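
# --- Illustrative sketch (not part of the example above) ---
# A minimal, self-contained toy showing the same within-graph finetuning
# pattern: a Python-level loop unrolls a few SGD steps on a small linear
# classifier, and tf.control_dependencies chains each step to the previous
# one. Every name below (toy_features, toy_labels, ...) is made up for
# illustration; only the control-dependency chaining mirrors the code above.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

toy_features = tf.constant(np.random.randn(20, 8).astype(np.float32))
toy_labels = tf.one_hot(np.random.randint(0, 4, size=20), depth=4)
w = tf.get_variable('toy_fc_w', shape=[8, 4])
b = tf.get_variable('toy_fc_b', shape=[4], initializer=tf.zeros_initializer())
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)

prev_step = tf.no_op()
for _ in range(5):  # Unrolled in the graph, like num_finetune_steps above.
  with tf.control_dependencies([prev_step]):
    logits = tf.matmul(toy_features, w) + b
    loss = tf.losses.softmax_cross_entropy(toy_labels, logits)
    prev_step = opt.minimize(loss, var_list=[w, b])

# Ops created here run only after the last finetune step has been applied.
with tf.control_dependencies([prev_step]):
  final_logits = tf.matmul(toy_features, w) + b

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(final_logits).shape)  # (20, 4)
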
Example #18
def train_q(dataset,
            policy,
            optimizer=None,
            pack_transition_fn=None,
            q_graph_fn=None,
            log_dir=None,
            master='',
            task=0,
            training_steps=None,
            max_training_steps=100000,
            reuse=False,
            init_checkpoint=None,
            update_target_every_n_steps=50,
            log_every_n_steps=None,
            save_checkpoint_steps=500,
            save_summaries_steps=500):
    """Self-contained learning loop for offline Q-learning.

  Code inspired by OpenAI Baselines' deepq.build_train. This function is
  compatible with discrete Q-learning graphs, continuous Q learning graphs, and
  SARSA.

  Args:
    dataset: tf.data.Dataset providing transitions.
    policy: Instance of TFDQNPolicy class that provides functor for building the
      critic function.
    optimizer: Optional instance of an optimizer. If not specified, creates an
      AdamOptimizer using the default constructor.
    pack_transition_fn: Optional function that performs additional processing
      of the transition. This is a convenience method for ad-hoc manipulation of
      transition data passed to the learning function after parsing.
    q_graph_fn: Function used to construct training objectives w.r.t. critic
      outputs.
    log_dir: Where to save model checkpoints and tensorboard summaries.
    master: Optional address of master worker. Specify this when doing
      distributed training.
    task: Optional worker task for distributed training. Defaults to solo master
      task on a single machine.
    training_steps: Optional number of steps to run training before terminating
      early. max_training_steps still applies: training will terminate once
      max_training_steps is reached, whether or not training_steps is
      specified.
    max_training_steps: Maximum number of training iterations.
    reuse: If True, reuse existing variables for all declared variables by this
      function.
    init_checkpoint: Optional checkpoint to restore prior to training. If not
      provided, variables are initialized using global_variables_initializer().
    update_target_every_n_steps: How many global steps (training) between
      copying the Q network weights (scope='q_func') to target network
      (scope='target_q_func').
    log_every_n_steps: How many global steps between logging loss tensors.
    save_checkpoint_steps: How many global steps between saving TF variables
      to a checkpoint file.
    save_summaries_steps: How many global steps between saving TF summaries.

  Returns:
    A (global_step, done) tuple: the current `global_step` reached after
    training for `training_steps` (or `max_training_steps` if that limit was
    hit first), and a boolean indicating whether `global_step` has reached
    `max_training_steps`.

  Raises:
    ValueError: If a batch of transitions is empty, or if its zeroth element is
      empty when it is expected to have length batch_size.
  """
    data_iterator = dataset.make_one_shot_iterator()

    transition = data_iterator.get_next()
    if pack_transition_fn:
        transition = pack_transition_fn(transition)

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer()

    q_func = policy.get_q_func(is_training=True, reuse=reuse)
    loss, all_summaries = q_graph_fn(q_func, transition)

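    # Collect the online ('q_func') and target ('target_q_func') network
    # variables by scope. The target set may be empty if the policy does not
    # build a separate target network, in which case the periodic sync
    # further below is skipped.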
    q_func_vars = contrib_framework.get_trainable_variables(scope='q_func')
    target_q_func_vars = contrib_framework.get_trainable_variables(
        scope='target_q_func')
    global_step = tf.train.get_or_create_global_step()

    # Only optimize q_func and update its batchnorm params.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='q_func')
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss,
                                      global_step=global_step,
                                      var_list=q_func_vars)

    chief_hooks = []
    hooks = []
    # Save summaries periodically.
    if save_summaries_steps is not None:
        chief_hooks.append(
            tf.train.SummarySaverHook(save_steps=save_summaries_steps,
                                      output_dir=log_dir,
                                      summary_op=all_summaries))

    # Stop after max_training_steps.
    if max_training_steps:
        hooks.append(tf.train.StopAtStepHook(last_step=max_training_steps))

    # Report if loss tensor is NaN.
    hooks.append(tf.train.NanTensorHook(loss))

    if log_every_n_steps is not None:
        tensor_dict = {'global_step': global_step, 'train loss': loss}
        chief_hooks.append(
            tf.train.LoggingTensorHook(tensor_dict,
                                       every_n_iter=log_every_n_steps))

        # Measure training speed (global steps per second) and save it to summaries.
        chief_hooks.append(
            tf.train.StepCounterHook(every_n_steps=log_every_n_steps,
                                     output_dir=log_dir))

    # If a target network exists, periodically copy the online Q-network
    # weights (scope 'q_func') to the frozen target network (scope
    # 'target_q_func'). This is implemented by piggybacking on a
    # LoggingTensorHook: fetching the logged tensor every N steps triggers
    # the assign ops through the control dependency below.
    if target_q_func_vars and update_target_every_n_steps is not None:
        update_target_expr = []
        for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                              sorted(target_q_func_vars,
                                     key=lambda v: v.name)):
            update_target_expr.append(var_t.assign(var))
        update_target_expr = tf.group(*update_target_expr)

        with tf.control_dependencies([update_target_expr]):
            update_target = tf.constant(0)
        chief_hooks.append(
            tf.train.LoggingTensorHook(
                {'update_target': update_target},
                every_n_iter=update_target_every_n_steps))

    # Save checkpoints periodically and keep all of them (max_to_keep=None).
    saver = tf.train.Saver(max_to_keep=None)
    chief_hooks.append(
        tf.train.CheckpointSaverHook(log_dir,
                                     save_steps=save_checkpoint_steps,
                                     saver=saver,
                                     checkpoint_basename='model.ckpt'))

    # Save our experiment params to checkpoint dir.
    chief_hooks.append(
        gin.tf.GinConfigSaverHook(log_dir, summarize_config=True))

    session_config = tf.ConfigProto(log_device_placement=False)

    init_fn = None
    if init_checkpoint:
        assign_fn = contrib_framework.assign_from_checkpoint_fn(
            init_checkpoint, contrib_framework.get_model_variables())
        init_fn = lambda _, sess: assign_fn(sess)
    scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)
    with tf.train.MonitoredTrainingSession(
            master=master,
            is_chief=(task == 0),
            config=session_config,
            checkpoint_dir=log_dir,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=chief_hooks) as sess:
        np_step = 0
        while not sess.should_stop():
            np_step, _ = sess.run([global_step, train_op])
            if training_steps and np_step % training_steps == 0:
                break
        done = np_step >= max_training_steps
    return np_step, done
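
# --- Illustrative sketch (not part of the example above) ---
# A minimal, runnable toy showing the MonitoredTrainingSession + hooks
# pattern used by train_q: a StopAtStepHook bounds training, a NanTensorHook
# guards against NaN losses, and a LoggingTensorHook logs progress, while the
# session handles initialization, checkpointing and the stop condition. The
# toy quadratic loss and the checkpoint directory are made up for
# illustration.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.get_variable('x', shape=[], initializer=tf.constant_initializer(5.0))
loss = tf.square(x - 2.0)
tf.summary.scalar('loss', loss)
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

hooks = [
    tf.train.StopAtStepHook(last_step=100),
    tf.train.NanTensorHook(loss),
    tf.train.LoggingTensorHook({'step': global_step, 'loss': loss},
                               every_n_iter=20),
]

with tf.train.MonitoredTrainingSession(
    checkpoint_dir='/tmp/toy_train_q',  # Arbitrary path for the sketch.
    hooks=hooks) as sess:
  while not sess.should_stop():
    sess.run(train_op)
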
Example #19
  def forward_pass(self, data):
    """Computes the test logits of MAML.

    Args:
      data: A `meta_dataset.providers.Episode` containing the data for the
        episode.

    Returns:
      The output logits for the query data in this episode.
    """
    # Have to use one-hot labels since sparse softmax doesn't allow
    # second derivatives.
    support_embeddings_ = self.embedding_fn(
        data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
    support_embeddings = support_embeddings_['embeddings']
    embedding_vars_dict = support_embeddings_['params']

    # TODO(eringrant): Refactor to make use of
    # `functional_backbones.linear_classifier`, which allows Gin-configuration.
    with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
      embedding_depth = support_embeddings.shape.as_list()[-1]
      fc_weights = functional_backbones.weight_variable(
          [embedding_depth, self.logit_dim],
          weight_decay=self.classifier_weight_decay)
      fc_bias = functional_backbones.bias_variable([self.logit_dim])

    # A list of variable names, a list of corresponding Variables, and a list
    # of operations (possibly empty) that creates a copy of each Variable.
    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         embedding_vars_dict, make_copies=not self.is_training)

    # A Variable for the weights of the fc layer, a Variable for the bias of the
    # fc layer, and a list of operations (possibly empty) that copies them.
    (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
        fc_weights, fc_bias, make_copies=not self.is_training)

    fc_vars = [fc_weights, fc_bias]
    num_embedding_vars = len(embedding_vars)
    num_fc_vars = len(fc_vars)

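    # _cond and _body define the condition and body of the inner-loop
    # (task-adaptation) tf.while_loop constructed further below: each
    # iteration recomputes the support loss with the current fast weights
    # and applies one gradient-descent update to them.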
    def _cond(step, *args):
      del args
      num_steps = self.num_update_steps
      if not self.is_training:
        num_steps += self.additional_evaluation_update_steps
      return step < num_steps

    def _body(step, *args):
      """The inner update loop body."""
      updated_embedding_vars = args[0:num_embedding_vars]
      updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                             num_fc_vars]
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['embeddings']

      updated_fc_weights, updated_fc_bias = updated_fc_vars
      support_logits = tf.matmul(support_embeddings,
                                 updated_fc_weights) + updated_fc_bias

      support_logits = support_logits[:, 0:data.way]
      loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                             support_logits)

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

      with tf.control_dependencies([print_op]):
        updated_embedding_vars = gradient_descent_step(
            loss, updated_embedding_vars, self.first_order,
            self.adapt_batch_norm, self.alpha, False)['updated_vars']
        updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                                self.first_order,
                                                self.adapt_batch_norm,
                                                self.alpha,
                                                False)['updated_vars']

        step = step + 1
      return tuple([step] + list(updated_embedding_vars) +
                   list(updated_fc_vars))

    # MAML meta updates using query set examples from an episode.
    if self.zero_fc_layer:
      # To account for variable class sizes, we initialize the output
      # weights to zero. TODO: see if truncated normal initialization helps.
      zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
      zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
      fc_vars_init_ops = [zero_weights_op, zero_bias_op]
    else:
      fc_vars_init_ops = fc_vars_copy_ops

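    # Proto-MAML-style initialization: optionally derive the fc weights and
    # bias from class prototypes computed on the support embeddings, instead
    # of using the zero or copied initialization chosen above.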
    if self.proto_maml_fc_layer_init:
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']

      prototypes = metric_learners.compute_prototypes(
          support_embeddings, data.onehot_support_labels)
      pmaml_fc_weights = self.proto_maml_fc_weights(
          prototypes, zero_pad_to_max_way=True)
      pmaml_fc_bias = self.proto_maml_fc_bias(
          prototypes, zero_pad_to_max_way=True)
      fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

    # These control dependencies assign the value of each variable to its
    # corresponding copy variable. This is required at test time to initialize
    # the copies, since they are used in place of the original variables.
    with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
      # Make step a local variable as we don't want to save and restore it.
      step = tf.Variable(
          0,
          trainable=False,
          name='inner_step_counter',
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
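      # The while_loop threads [step] + embedding variables + fc variables
      # through _body, which returns the same structure with each variable
      # replaced by its updated (fast-weight) tensor.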
      loop_vars = [step] + embedding_vars + fc_vars
      step_and_all_updated_vars = tf.while_loop(
          _cond, _body, loop_vars, swap_memory=True)
      step = step_and_all_updated_vars[0]
      all_updated_vars = step_and_all_updated_vars[1:]
      updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
      updated_fc_weights, updated_fc_bias = all_updated_vars[
          num_embedding_vars:num_embedding_vars + num_fc_vars]

    # Forward pass the support images with the updated weights in order to
    # compute the batch norm means and variances to use for the query set.
    support_set_moments = None
    if not self.transductive_batch_norm:
      support_set_moments = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['moments']

    query_embeddings = self.embedding_fn(
        data.query_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        moments=support_set_moments,  # Use support set stats for batch norm.
        reuse=True,
        backprop_through_moments=self.backprop_through_moments)['embeddings']

    query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                    updated_fc_bias)[:, 0:data.way]

    return query_logits
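
# --- Illustrative sketch (not part of the example above) ---
# The inner update above relies on a helper, gradient_descent_step, whose
# implementation is not shown here. Below is a guess at the core of such a
# step for MAML: take the gradients of the support loss w.r.t. the current
# fast weights and return `var - alpha * grad` as new tensors, optionally
# stopping gradients through the gradient itself for the first-order
# approximation. This is an assumption for illustration, not the actual
# implementation (it ignores batch-norm handling and the extra arguments
# used above).
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def sketch_gradient_descent_step(loss, variables, alpha, first_order):
  """Returns one SGD-step-updated version of `variables` as new tensors."""
  grads = tf.gradients(loss, variables)
  if first_order:
    # Treat the gradients as constants so that no second derivatives flow
    # through them in the outer (meta) update.
    grads = [tf.stop_gradient(g) for g in grads]
  return {'updated_vars': [v - alpha * g for v, g in zip(variables, grads)]}


# Tiny usage on a toy quadratic loss (values made up for illustration).
theta = tf.constant([1.0, 2.0])
toy_loss = tf.reduce_sum(tf.square(theta - 3.0))
updated_theta = sketch_gradient_descent_step(
    toy_loss, [theta], alpha=0.1, first_order=True)['updated_vars'][0]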