Code example #1
  def _combine(self, concat, *argv):
    """Combines activation tensors: concatenates them along the channel
    axis (axis=3, NHWC) when `concat` is set, otherwise returns them as a
    tuple."""
    if concat:
      y = _concat(list(argv), axis=3)
    else:
      y = tuple(argv)
    return y
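
_concat is a helper that is not shown in this snippet; it presumably wraps tf.concat. Under that assumption, a minimal standalone sketch of the two code paths, written in eager TensorFlow purely for illustration:

import tensorflow as tf

def combine(concat, *argv):
  # Mirrors _combine: merge the activation tensors on the channel axis (NHWC),
  # or keep them as a tuple.
  if concat:
    return tf.concat(list(argv), axis=3)
  return tuple(argv)

h1 = tf.zeros([8, 32, 32, 16])
h2 = tf.zeros([8, 32, 32, 16])
print(combine(True, h1, h2).shape)   # (8, 32, 32, 32)
print(len(combine(False, h1, h2)))   # 2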
Code example #2
  def _compute_gradients(self, cost):
    """Computes gradients.
    Args:
      cost: Loss function.
    Returns:
      grads_and_vars: List of tuple of gradients and variables.
    """
    config = self.config
    if not config.manual_gradients:
      return super(RevNetModel, self)._compute_gradients(cost)
    log.warning("Manually building gradient graph.")
    g = tf.get_default_graph()
    tf.get_variable_scope().reuse_variables()
    num_stages = len(self.config.num_residual_units)
    
    beta_final = tf.get_variable("unit_last/final_bn/beta")
    gamma_final = tf.get_variable("unit_last/final_bn/gamma")
    w_final = tf.get_variable("logit/w")
    b_final = tf.get_variable("logit/b")
    filters = [ff for ff in self.config.filters]  # Copy filter config.

    if config.use_bottleneck:
      res_func = self._bottleneck_residual_backward
      # For CIFAR-10 it's [16, 16, 32, 64] => [16, 64, 128, 256]
      for ii in range(1, len(filters)):
        filters[ii] *= 4
    else:
      res_func = self._residual_backward

    grads_list = []
    vars_list = []
    var_final = [beta_final, gamma_final, w_final, b_final]
   
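    # Re-run the classification head (final BN, ReLU, global average pool,
    # logits, cross-entropy) on the last saved activation pair, with gradients
    # stopped at (h1, h2), so dh1/dh2 and the final-layer variable gradients
    # come from this small sub-graph alone.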
    h1, h2 = self._saved_hidden[-1]
    h1, h2 = tf.stop_gradient(h1), tf.stop_gradient(h2)
    h = _concat([h1, h2], axis=3)
    with tf.variable_scope("unit_last"):
      h = self._batch_norm("final_bn", h, add_ops=False)
      h = self._relu("final_relu", h)

    h = self._global_avg_pool(h)
    with tf.variable_scope("logit"):
      logits = self._fully_connected(h, config.num_classes)
    with tf.variable_scope("costs"):
      xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=self.label)
      cost = tf.reduce_mean(xent, name="xent")

    _grads = tf.gradients(cost, [h1, h2] + var_final, gate_gradients=True)
    dh1, dh2 = _grads[0], _grads[1]
    _grads = _grads[2:]
    # Injected dependency.
    with tf.control_dependencies(_grads):
      h_grad = (tf.identity(dh1), tf.identity(dh2))
    grads_list.extend(_grads)
    # grads_list.extend(_grads[2:])
    vars_list.extend(var_final)

    h1, h2 = self._saved_hidden[-1]
    h1, h2 = tf.stop_gradient(h1), tf.stop_gradient(h2)
    h = (h1, h2)

    # Backward sweep over all residual units in a single for-loop, from the
    # last stage down to the first: each unit's input is either read from the
    # saved stage-boundary activation or reconstructed from its output, then
    # the unit is re-run to collect weight and activation gradients.
    ss = num_stages - 1
    ii = config.num_residual_units[ss] - 1
    nlayers = sum(config.num_residual_units)
    for ll in range(nlayers - 1, -1, -1):
      no_activation = False
      if ii == 0:
        in_filter = filters[ss]
        stride = self._stride_arr(self.config.strides[ss])
        if ss == 0:
          no_activation = True
      else:
        in_filter = filters[ss + 1]
        stride = self._stride_arr(1)
      out_filter = filters[ss + 1]

      with tf.variable_scope("unit_{}_{}".format(ss + 1, ii)):

        # Reconstruct input.
        if ii == 0:
          h = self._saved_hidden[ss]
        else:
          h = res_func(h, out_filter)

        # Rerun the layer, and get gradients.
        h_grad, w_list, w_grad = self._residual_grad(
            h,
            h_grad,
            in_filter,
            out_filter,
            stride,
            no_activation=no_activation)

        grads_list.extend(w_grad)
        vars_list.extend(w_list)

      # Counter.
      if ii == 0:
        ss -= 1
        ii = config.num_residual_units[ss] - 1
      else:
        ii -= 1

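    # Concatenate the two activation-gradient halves and backpropagate them
    # through the initial convolution / batch norm to obtain gradients for the
    # init variables.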
    h_grad = _concat(h_grad, axis=3)
    w_init = tf.get_variable("init/init_conv/w")
    beta_init = tf.get_variable("init/init_bn/beta")
    gamma_init = tf.get_variable("init/init_bn/gamma")
    var_init = [beta_init, gamma_init, w_init]
    _grads = tf.gradients(h, var_init, h_grad)
    grads_list.extend(_grads)
    vars_list.extend(var_init)

    # Add weight decay.
    def add_wd(x):
      g, w = x[0], x[1]
      assert self._wd_hidden > 0.0, "Not applying weight decay"
      if w.name.endswith("w:0") and self._wd_hidden > 0.0:
        log.info("Adding weight decay {:.4e} for variable {}".format(
            self._wd_hidden, x[1].name))
        return g + self._wd_hidden * w, w
      else:
        return g, w

    # Always gate gradients to avoid unwanted behaviour.
    return list(map(add_wd, zip(tf.tuple(grads_list), vars_list)))
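
For context, the manual gradient graph above only works because each residual unit is reversible: res_func (_residual_backward or _bottleneck_residual_backward) recomputes a unit's input from its output rather than reading it from stored activations. A minimal NumPy sketch of that identity, with F and G as hypothetical stand-ins for the two residual branches:

import numpy as np

rng = np.random.default_rng(0)
F = lambda x: np.tanh(x)        # stand-in for the first residual branch
G = lambda x: np.tanh(2.0 * x)  # stand-in for the second residual branch

x1 = rng.normal(size=(1, 4, 4, 8))
x2 = rng.normal(size=(1, 4, 4, 8))

# Forward pass of one reversible unit.
y1 = x1 + F(x2)
y2 = x2 + G(y1)

# Backward reconstruction: recover the inputs from the outputs alone.
x2_rec = y2 - G(y1)
x1_rec = y1 - F(x2_rec)

assert np.allclose(x1, x1_rec) and np.allclose(x2, x2_rec)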