Example No. 2
def _feature_transformer(input_data):
  """Feature transformer at the end of training."""
  # `base_model` and `new_variables` are closed over from the enclosing scope.
  initial_variables = base_model.get_variables()
  # Pair each original variable with its replacement value.
  replacement = collections.OrderedDict(
      utils.eqzip(initial_variables, new_variables))
  # Temporarily swap in the new variables while computing features.
  with variable_replace.variable_replace(replacement):
    features = base_model(input_data)
  return features
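
Every example on this page leans on `utils.eqzip`. Judging from how it is used (zipping sequences that must line up one to one), a minimal sketch of such a helper could look like the following; the project's actual implementation may differ in detail.

def eqzip(*sequences):
  """zip() that insists every input sequence has the same length.

  Only a sketch inferred from the usage in these examples; the real
  `utils.eqzip` may differ.
  """
  lengths = [len(s) for s in sequences]
  if len(set(lengths)) > 1:
    raise ValueError('eqzip got sequences of unequal length: %r' % lengths)
  return zip(*sequences)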
Example No. 4
  def assign_state(self, base_model, next_state):
    """Assigns `next_state` back onto the base model and its optimizer."""
    # Assign each stored value onto the matching model variable.
    var_ups = [
        v.assign(nv) for v, nv in utils.eqzip(base_model.get_variables(),
                                              next_state.variables)
    ]

    # Restore the optimizer's internal state as well.
    opt_ups = self.opt.assign_state(next_state.opt_state)

    # A single grouped op that performs every assignment at once.
    return tf.group(opt_ups, *var_ups)
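
`assign_state` bundles all of the variable assignments and the optimizer-state assignment into one op via `tf.group`, so restoring a state is a single `session.run` call. A minimal self-contained sketch of that pattern, assuming the TF1-style graph/session API these examples appear to use:

import tensorflow as tf

# Two toy variables and their assignment ops, grouped into a single op.
v1 = tf.Variable(0.0)
v2 = tf.Variable(0.0)
grouped = tf.group(v1.assign(1.0), v2.assign(2.0))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(grouped)            # runs both assignments together
  print(sess.run([v1, v2]))    # [1.0, 2.0]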
Example No. 5
  def compute_updates(self, xs, gs, learning_rates, state):
    """Computes an updated value for each (x, g, lr) triple."""
    new_vars = []
    for x, g, lr in utils.eqzip(xs, gs, learning_rates):
      if lr is None:
        lr = self.learning_rate
      if g is not None:
        # Decay the parameter and take a gradient step, both scaled by lr.
        new_vars.append(x * (1 - lr) - g * lr)
      else:
        # No gradient available; leave the parameter unchanged.
        new_vars.append(x)
    return new_vars, state
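
The per-parameter update in `compute_updates` folds a multiplicative decay and a gradient step into the single expression `x * (1 - lr) - g * lr`. A short plain-Python illustration with made-up numbers:

# Illustration of the update rule x * (1 - lr) - g * lr.
x, g, lr = 1.0, 0.5, 0.1

decayed = x * (1 - lr)   # 0.9  -- shrink the parameter toward zero
step = g * lr            # 0.05 -- scaled gradient step
new_x = decayed - step   # 0.85

assert abs(new_x - (x * (1 - lr) - g * lr)) < 1e-12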
Example No. 7
  def compute_next_state(self, outputs, meta_opt, previous_state):
    zs = outputs.zs
    xs = outputs.xs
    batch = outputs.batch
    mods = outputs.mods
    backward_mods = outputs.backward_mods
    variables = self.get_variables()

    rev_mods = mods[::-1]
    rev_backward_mods = backward_mods[::-1]
    rev_xs = xs[::-1]
    rev_zs = zs[::-1] + [None]

    to_top = xs[-1]

    # variables that change in the loop
    hs = []
    d = meta_opt.compute_top_delta(to_top)  # [bs x 32 x delta_channels]

    iterator = utils.eqzip(rev_backward_mods + [None], rev_mods + [None],
                           [None] + rev_mods, rev_xs, rev_zs)
    for (backward_mod, lower_mod, upper_mod, x, z) in iterator:
      w_bot = None
      if lower_mod is not None:
        w_bot = previous_state.variables[variables.index(lower_mod.w)]
      w_top = None
      if upper_mod is not None:
        w_top = previous_state.variables[variables.index(upper_mod.w)]
      backward_w = None
      if backward_mod is not None:
        backward_w = previous_state.variables[variables.index(backward_mod.w)]
      if lower_mod is not None:
        bias = previous_state.variables[variables.index(lower_mod.b)]
      else:
        bias = tf.zeros([x.shape[1]])

      h, d = self.compute_next_h_d(
          meta_opt=meta_opt,
          w_bot=w_bot,
          w_top=w_top,
          bias=bias,
          backward_w=backward_w,
          x=x,
          z=z,
          d=d)
      hs.append(h)

    w_forward_var_idx = [variables.index(mod.w) for mod in rev_mods]
    w_backward_var_idx = [variables.index(mod.w) for mod in rev_backward_mods]
    b_var_idx = [variables.index(mod.b) for mod in rev_mods]

    # storage for the outputs of the loop below
    grads = [None for _ in previous_state.variables]

    # override learning rate for perturbation variables
    learning_rate = [None for _ in previous_state.variables]

    # This is a map -- no state is shared across loop iterations
    for l_idx, w_forward_idx, w_backward_idx, b_idx, upper_h, lower_h, lower_x, upper_x in utils.eqzip(
        range(len(w_forward_var_idx)), w_forward_var_idx, w_backward_var_idx,
        b_var_idx, hs[:-1], hs[1:], xs[::-1][1:], xs[::-1][:-1]):

      b_base = previous_state.variables[b_idx]
      change_w_forward, change_b = self.weight_change_for_layer(
          meta_opt=meta_opt,
          l_idx=l_idx,
          w_base=previous_state.variables[w_forward_idx],
          b_base=b_base,
          upper_h=upper_h,
          lower_h=lower_h,
          upper_x=upper_x,
          lower_x=lower_x,
          prefix='forward',
          include_bias=True)

      if self.identical_updates:
        change_w_backward = change_w_forward
      else:
        change_w_backward = self.weight_change_for_layer(
            meta_opt=meta_opt,
            l_idx=l_idx,
            w_base=previous_state.variables[w_backward_idx],
            b_base=b_base,
            upper_h=upper_h,
            lower_h=lower_h,
            upper_x=upper_x,
            lower_x=lower_x,
            prefix='backward',
            include_bias=False)

      grads[w_forward_idx] = change_w_forward

      grads[w_backward_idx] = change_w_backward

      grads[b_idx] = change_b

    cur_transformer = common.transformer_at_state(self,
                                                  previous_state.variables)
    next_state = meta_opt.compute_next_state(
        grads,
        learning_rate=learning_rate,
        cur_state=previous_state,
        cur_transformer=lambda x: cur_transformer(x)[0])
    return next_state
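
The top-to-bottom walk in `compute_next_state` gets its lower/upper layer pairing from the `[None]` padding passed to `eqzip`: padding `rev_mods` on the right yields the module below each activation, and padding it on the left yields the module above, with `None` marking the ends. A minimal sketch with made-up string stand-ins for the layers and activations shows the pairing this produces:

def eqzip(*seqs):
  # Equal-length zip, as sketched earlier on this page.
  assert len(set(len(s) for s in seqs)) <= 1
  return zip(*seqs)

# Three made-up layers, ordered bottom -> top, and their activations.
mods = ['layer0', 'layer1', 'layer2']
xs = ['x_in', 'x0', 'x1', 'x2']   # one more activation than layers

rev_mods = mods[::-1]
rev_xs = xs[::-1]

for lower_mod, upper_mod, x in eqzip(rev_mods + [None], [None] + rev_mods,
                                     rev_xs):
  print(lower_mod, upper_mod, x)
# layer2 None   x2    <- topmost activation, nothing above it
# layer1 layer2 x1
# layer0 layer1 x0
# None   layer0 x_in  <- network input, nothing below it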