コード例 #1
0
    def get_avg_loss(self, learner, inner_batches, init_state):
        """Return the mean meta-loss of one fixed state over the unroll batches.

        For every step ``i`` in ``[0, self.unroll_n_steps)`` the batch at index
        ``i * self.batches_per_step`` is fetched through ``self.get_batch`` and
        ``learner.meta_loss`` is evaluated on ``init_state`` with that batch.
        The state is not advanced between steps.

        Args:
          learner: object providing ``current_state()`` and
            ``meta_loss(state, loss_state=...)``.
          inner_batches: nested structure of tensors, each unstacked into a
            ``tf.TensorArray`` of size ``self.unroll_n_steps``; or None to use
            ``self.inner_batches`` as is.
          init_state: state the loss is evaluated at; or None to use
            ``tf_utils.force_copy(learner.current_state())``.

        Returns:
          The sum of the per-step losses divided by ``self.unroll_n_steps``.
        """
        if inner_batches is not None:
            # Caller-provided batches arrive stacked; split each tensor into a
            # TensorArray so individual batches can be read inside the loop.
            # NOTE(review): the arrays hold unroll_n_steps entries while the
            # loop asks for index i * batches_per_step; whether that stays in
            # range depends on get_batch -- confirm.
            def _unstack_to_array(tensor):
                array = tf.TensorArray(dtype=tensor.dtype,
                                       size=self.unroll_n_steps)
                return array.unstack(tensor)

            inner_batches = nest.map_structure(_unstack_to_array,
                                               inner_batches)
        else:
            inner_batches = self.inner_batches

        if init_state is None:
            init_state = tf_utils.force_copy(learner.current_state())

        def keep_going(_, step):
            return tf.less(step, self.unroll_n_steps)

        def add_step_loss(total, step):
            batch = self.get_batch(step * self.batches_per_step,
                                   batches=inner_batches)
            step_loss = learner.meta_loss(init_state, loss_state=batch)
            return total + step_loss, step + 1

        # Loop variables: (running loss total, step counter).
        loop_vars = [tf.constant(0.0), tf.constant(0)]
        total_loss, _ = tf.while_loop(keep_going, add_step_loss, loop_vars)
        return total_loss / self.unroll_n_steps
コード例 #2
0
def build_evaluation_graph(loss_module_fn=gin.REQUIRED,
                           learner_fn=gin.REQUIRED):
  """Build the evaluation graph for inner-training.

  Args:
    loss_module_fn: zero-argument callable returning the loss module.
    learner_fn: callable taking the loss module and returning a
      `(learner, theta_mod)` pair.

  Returns:
    A dict with the following entries:
      "train_op": op grouping the assignment of the next learner state with
        the increment of the global step.
      "global_step": result of `tf.train.get_or_create_global_step()`.
      "init_op": op assigning `learner.initial_state()` to the learner.
      "checkpoint_vars": list of `theta_mod`'s global variables.
      "meta_loss": `learner.meta_loss(state)`, wrapped in `tf.Print`.
      "train_loss": `learner.inner_loss(state)`, wrapped in `tf.Print`.
      "current_state": copy of the learner's current state.
      "next_state": state returned by `learner.loss_and_next_state`.
  """
  global_step = tf.train.get_or_create_global_step()

  loss_module = loss_module_fn()
  learner, theta_mod = learner_fn(loss_module)

  initial_state = learner.initial_state()
  reset_state_op = learner.assign_state(initial_state)
  state = learner.current_state()
  state = tf_utils.force_copy(state)

  # Make sure the state copy exists before the next state is computed from it.
  with tf.control_dependencies(nest.flatten(state)):
    last_loss, new_state = learner.loss_and_next_state(state)

  # Only write the new state back once it (and the loss) has been computed.
  with tf.control_dependencies([last_loss] + nest.flatten(new_state)):
    assign_next_state = learner.assign_state(new_state)

  update_global_step = global_step.assign_add(1)
  # Group exactly once: grouping a second time under the same name would only
  # add a wrapper op that TensorFlow renames to keep graph names unique.
  train_op = tf.group(assign_next_state, update_global_step, name="train_op")

  load_vars = list(theta_mod.get_variables(tf.GraphKeys.GLOBAL_VARIABLES))
  meta_loss = learner.meta_loss(state)
  meta_loss = tf.Print(meta_loss, [meta_loss], "meta_loss")

  inner_loss = learner.inner_loss(state)
  inner_loss = tf.Print(inner_loss, [inner_loss], "inner_loss")
  # TODO(lmetz) this should only be scalars.
  # NOTE(review): it is unclear what "this" in the TODO refers to -- confirm.

  return {
      "train_op": train_op,
      "global_step": global_step,
      "init_op": reset_state_op,
      "checkpoint_vars": load_vars,
      "meta_loss": meta_loss,
      "train_loss": inner_loss,
      "current_state": state,
      "next_state": new_state,
  }
コード例 #3
0
    def get_phi_trajectory(self, learner, inner_batches=None, init_state=None):
        """Unroll the learner and collect every state along the way.

        Starting from ``init_state``, ``learner.loss_and_next_state`` is
        applied ``self.unroll_n_steps`` times, using batch ``i`` (fetched with
        ``self.get_batch``) at step ``i``. Each leaf of the state is recorded
        in its own ``tf.TensorArray`` with ``self.unroll_n_steps + 1`` slots:
        slot 0 holds the initial value and slot ``i + 1`` the value after
        step ``i``.

        Args:
          learner: object providing ``current_state()`` and
            ``loss_and_next_state(state, loss_state=...)``.
          inner_batches: nested structure of tensors, each unstacked into a
            ``tf.TensorArray`` of size ``self.unroll_n_steps``; or None to use
            ``self.inner_batches`` as is.
          init_state: state to start from; or None to use
            ``tf_utils.force_copy(learner.current_state())``.

        Returns:
          A structure matching ``init_state`` in which every leaf is the
          stacked contents of that leaf's TensorArray.
        """
        if inner_batches is not None:
            # Caller-provided batches arrive stacked; split each tensor into a
            # TensorArray so individual batches can be read inside the loop.
            def _unstack_to_array(tensor):
                array = tf.TensorArray(dtype=tensor.dtype,
                                       size=self.unroll_n_steps)
                return array.unstack(tensor)

            inner_batches = nest.map_structure(_unstack_to_array,
                                               inner_batches)
        else:
            inner_batches = self.inner_batches

        if init_state is None:
            init_state = tf_utils.force_copy(learner.current_state())

        def start_history(value):
            # One extra slot so the initial value can sit at index 0.
            history = tf.TensorArray(dtype=value.dtype,
                                     size=self.unroll_n_steps + 1)
            return history.write(0, value)

        def keep_going(state, histories, step):  # pylint: disable=unused-argument
            return tf.less(step, self.unroll_n_steps)

        def advance(state, histories, step):
            batch = self.get_batch(step, batches=inner_batches)
            _, following = learner.loss_and_next_state(state,
                                                       loss_state=batch)

            def record(history, value):
                # Offset by one: this is the state *after* the current step.
                return history.write(step + 1, value)

            histories = self._nest_bimap(record, histories, following)
            return following, histories, step + 1

        initial_histories = nest.map_structure(start_history, init_state)
        loop_vars = [init_state, initial_histories, tf.constant(0)]
        _, histories, _ = tf.while_loop(keep_going, advance, loop_vars)
        return nest.map_structure(lambda history: history.stack(), histories)
コード例 #4
0
def compute_meta_loss(learner,
                      unroll_n_steps,
                      init_state=None,
                      extra_loss_eval=5):
    """Helper function to compute the training objective.

    This function unrolls `unroll_n_steps` and accumulates the loss.
    Additionally, to lower variance, at each new state, an extra
    `extra_loss_eval` losses are computed and added to the loss.

    TODO(lmetz) a more rigorous analysis of variance of gradients to pick
    these parameters.

    Args:
      learner: Learner instance
      unroll_n_steps: number of steps to unroll
      init_state: initial LearnerState; when None,
        `tf_utils.force_copy(learner.current_state())` is used.
      extra_loss_eval: int, number of `learner.meta_loss` evaluations summed
        per state; when not positive the step's own loss is used instead.
    Returns:
      meta_loss, final LearnerState
    """
    if init_state is None:
        init_state = learner.current_state()
        init_state = tf_utils.force_copy(init_state)

    # Carried value: a (loss, learner state) pair, starting at zero loss.
    current_state = (tf.constant(0., dtype=tf.float32), init_state)

    # The helpers below take the (loss, state) pair as ONE positional argument
    # and unpack it in the body; tuple parameters in the signature are not
    # valid syntax on Python 3.
    def loss_and_next_state_fn(loss_and_state):
        """Advance the learner one step; the carried loss is ignored."""
        _, state = loss_and_state
        return learner.loss_and_next_state(state)

    def accumulate_fn(loss_and_state, a):
        """Return the running total `a` plus this state's contribution."""
        step_loss, s = loss_and_state
        if extra_loss_eval > 0:
            # Sum `extra_loss_eval` evaluations of the meta-loss at `s`; the
            # step's own loss is not added on this branch.
            cond = lambda i, a: tf.less(i, extra_loss_eval)
            body = lambda i, a: (i + 1, a + learner.meta_loss(s))
            _, extra_losses = tf.while_loop(cond, body, loop_vars=[0, 0.])
            return a + extra_losses
        else:
            return a + step_loss

    # NOTE(review): nothing in the visible source consumes `current_state`,
    # `loss_and_next_state_fn` or `accumulate_fn`, and there is no return
    # statement; the unrolling/return code appears to be missing -- confirm.
コード例 #5
0
    def __call__(self,
                 learner,
                 meta_batches=None,
                 inner_batches=None,
                 init_state=None,
                 unroll_n_steps=None):
        if unroll_n_steps is None:
            unroll_n_steps = self.unroll_n_steps
        else:
            print("Using passed in unroll steps")

        if inner_batches is None:
            inner_batches = self.inner_batches
        else:
            # convert the batches object to a tensorarray.
            def to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.unroll_n_steps).unstack(t)

            inner_batches = nest.map_structure(to_ta, inner_batches)

        if meta_batches is None:
            meta_batches = self.meta_batches
        else:
            # convert the batches object to a tensorarray.
            def ml_to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.meta_loss_evals *
                                      self.unroll_n_steps).unstack(t)

            meta_batches = nest.map_structure(ml_to_ta, meta_batches)

        if init_state is None:
            init_state = learner.current_state()
            init_state = tf_utils.force_copy(init_state)

        current_state = (tf.constant(0, dtype=tf.int32),
                         tf.constant(0., dtype=tf.float32), init_state)

        def loss_and_next_state_fn((idx, l, state)):
            batch = self.get_batch(idx, batches=inner_batches)
            l, s = learner.loss_and_next_state(state, loss_state=batch)
            return (idx + 1, l, s)

        def accumulate_fn((idx, _, s), (a_meta, a_inner)):
            """Accumulate loss for fold learning process."""
            cond = lambda i, a: tf.less(i, self.meta_loss_evals)

            def body_meta(i, a):
                # minus 1 as this takes the following step.
                batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i,
                                       batches=meta_batches)
                return (i + 1, a + learner.meta_loss(s, loss_state=batch))

            _, extra_losses = tf.while_loop(cond, body_meta, loop_vars=[0, 0.])

            def body_inner(i, a):
                # minus 1 as this takes the following step.
                batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i,
                                       batches=meta_batches)
                return (i + 1, a + learner.inner_loss(s, loss_state=batch))

            _, inner_losses = tf.while_loop(cond,
                                            body_inner,
                                            loop_vars=[0, 0.])

            return a_meta + extra_losses, a_inner + inner_losses